From e4cae4490e536fea4235a8b2a0612129342b155e Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 7 May 2026 01:03:50 -0700 Subject: [PATCH 1/4] perf: Avoid creating processing buffers beyond what is needed. In Dart, processing buffers are sliced up from the merge buffer. For stages that do not use all processing threads -- perhaps because they do not have enough inputs -- we can be more efficient with memory by slicing the merge buffer based on the actual number of processors, not the number of processing threads. This patch addresses it by deferring the choice of how many buffers are needed until the stage actually starts executing. At that point, it knows how many processors it will create. --- .../msq/dart/worker/DartFrameContext.java | 42 +++++- .../worker/DartProcessingBuffersProvider.java | 59 ++++++-- .../msq/dart/worker/DartWorkerContext.java | 3 +- .../apache/druid/msq/exec/FrameContext.java | 11 ++ .../druid/msq/exec/ProcessingBuffersSet.java | 103 ++++++++----- .../msq/indexing/IndexerFrameContext.java | 38 ++++- .../msq/indexing/IndexerWorkerContext.java | 2 +- .../msq/querykit/BaseLeafStageProcessor.java | 4 + .../DartProcessingBuffersProviderTest.java | 142 ++++++++++++++---- .../msq/exec/ProcessingBuffersSetTest.java | 74 +-------- .../druid/msq/test/MSQTestWorkerContext.java | 6 + 11 files changed, 322 insertions(+), 162 deletions(-) diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java index d17a94754ab7..fca50d51efe0 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java @@ -27,11 +27,11 @@ import org.apache.druid.msq.exec.FrameContext; import org.apache.druid.msq.exec.FrameWriterSpec; import org.apache.druid.msq.exec.ProcessingBuffers; +import 
org.apache.druid.msq.exec.ProcessingBuffersSet; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.rowsandcols.serde.WireTransferableContext; import org.apache.druid.segment.IndexIO; @@ -42,6 +42,7 @@ import org.apache.druid.segment.loading.DataSegmentPusher; import org.apache.druid.server.SegmentManager; +import javax.annotation.Nullable; import java.io.File; /** @@ -52,24 +53,33 @@ public class DartFrameContext implements FrameContext private final StageId stageId; private final FrameWriterSpec frameWriterSpec; private final SegmentWrangler segmentWrangler; - private final GroupingEngine groupingEngine; private final SegmentManager segmentManager; private final CoordinatorClient coordinatorClient; private final WorkerContext workerContext; - private final ResourceHolder processingBuffers; + + /** + * Null if the stage does not use processing buffers. + */ + @Nullable + private final ProcessingBuffersSet processingBuffersSet; private final WorkerMemoryParameters memoryParameters; private final WorkerStorageParameters storageParameters; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; + /** + * Acquired by {@link #acquireProcessingBuffers}. 
+ */ + @Nullable + private ResourceHolder processingBuffers; + public DartFrameContext( final StageId stageId, final WorkerContext workerContext, final FrameWriterSpec frameWriterSpec, final SegmentWrangler segmentWrangler, - final GroupingEngine groupingEngine, final SegmentManager segmentManager, final CoordinatorClient coordinatorClient, - final ResourceHolder processingBuffers, + @Nullable final ProcessingBuffersSet processingBuffersSet, final WorkerMemoryParameters memoryParameters, final WorkerStorageParameters storageParameters, final DataServerQueryHandlerFactory dataServerQueryHandlerFactory @@ -78,11 +88,10 @@ public DartFrameContext( this.stageId = stageId; this.segmentWrangler = segmentWrangler; this.frameWriterSpec = frameWriterSpec; - this.groupingEngine = groupingEngine; this.segmentManager = segmentManager; this.coordinatorClient = coordinatorClient; this.workerContext = workerContext; - this.processingBuffers = processingBuffers; + this.processingBuffersSet = processingBuffersSet; this.memoryParameters = memoryParameters; this.storageParameters = storageParameters; this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; @@ -160,9 +169,24 @@ public IndexMerger indexMerger() throw DruidException.defensive("Ingestion not implemented"); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + if (processingBuffersSet == null) { + throw DruidException.defensive("Stage[%s] does not use processing buffers", stageId); + } + if (processingBuffers != null) { + throw DruidException.defensive("Processing buffers already acquired"); + } + processingBuffers = processingBuffersSet.acquire(requestedSlices); + } + @Override public ProcessingBuffers processingBuffers() { + if (processingBuffers == null) { + throw DruidException.defensive("Processing buffers not yet acquired"); + } return processingBuffers.get(); } @@ -193,6 +217,8 @@ public FrameWriterSpec frameWriterSpec() @Override public void close() { - 
processingBuffers.close(); + if (processingBuffers != null) { + processingBuffers.close(); + } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java index 533f99073a2a..0ee2a276e7ea 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java @@ -23,6 +23,7 @@ import org.apache.druid.collections.QueueNonBlockingPool; import org.apache.druid.collections.ReferenceCountingResourceHolder; import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.msq.exec.ProcessingBuffers; import org.apache.druid.msq.exec.ProcessingBuffersProvider; @@ -39,7 +40,7 @@ /** * Production implementation of {@link ProcessingBuffersProvider} that uses the merge buffer pool. Each call - * to {@link #acquire(int)} acquires one merge buffer and slices it up. + * to {@link #acquire(int, long)} acquires one merge buffer and slices it up. 
*/ public class DartProcessingBuffersProvider implements ProcessingBuffersProvider { @@ -67,27 +68,55 @@ public ResourceHolder acquire(final int poolSize, final lo final ReferenceCountingResourceHolder bufferHolder = batch.get(0); try { final ByteBuffer buffer = bufferHolder.get().duplicate(); - final int sliceSize = buffer.capacity() / poolSize / processingThreads; - final List pool = new ArrayList<>(poolSize); + final int chunkSize = buffer.capacity() / poolSize; + final List slots = new ArrayList<>(poolSize); for (int i = 0; i < poolSize; i++) { - final BlockingQueue queue = new ArrayBlockingQueue<>(processingThreads); - for (int j = 0; j < processingThreads; j++) { - final int sliceNum = i * processingThreads + j; - buffer.position(sliceSize * sliceNum).limit(sliceSize * (sliceNum + 1)); - queue.add(buffer.slice()); - } - final ProcessingBuffers buffers = new ProcessingBuffers( - new QueueNonBlockingPool<>(queue), - new Bouncer(processingThreads) - ); - pool.add(buffers); + buffer.position(chunkSize * i).limit(chunkSize * (i + 1)); + slots.add(new LazySlot(buffer.slice(), processingThreads)); } - return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(pool), bufferHolder); + return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(slots), bufferHolder); } catch (Throwable e) { throw CloseableUtils.closeAndWrapInCatch(e, bufferHolder); } } + + /** + * Lazy slot that holds one chunk of the shared merge buffer and slices it on demand to match the stage's + * actual concurrent-processor count. 
+ */ + static final class LazySlot implements ProcessingBuffersSet.Slot + { + private final ByteBuffer chunk; + private final int maxSlices; + + LazySlot(final ByteBuffer chunk, final int maxSlices) + { + this.chunk = chunk; + this.maxSlices = maxSlices; + } + + @Override + public ProcessingBuffers acquire(final int requestedSlices) + { + if (requestedSlices > maxSlices) { + throw DruidException.defensive( + "requestedSlices[%d] too large for maxSlices[%d]", + requestedSlices, + maxSlices + ); + } + + final int sliceSize = chunk.capacity() / requestedSlices; + final BlockingQueue queue = new ArrayBlockingQueue<>(requestedSlices); + final ByteBuffer working = chunk.duplicate(); + for (int j = 0; j < requestedSlices; j++) { + working.position(sliceSize * j).limit(sliceSize * (j + 1)); + queue.add(working.slice()); + } + return new ProcessingBuffers(new QueueNonBlockingPool<>(queue), new Bouncer(requestedSlices)); + } + } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java index bde9f8968a0d..768801261d1b 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java @@ -250,10 +250,9 @@ public FrameContext frameContext(WorkOrder workOrder) this, FrameWriterSpec.fromContext(workOrder.getWorkerContext()), segmentWrangler, - groupingEngine, segmentManager, coordinatorClient, - processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + workOrder.getStageDefinition().getProcessor().usesProcessingBuffers() ? 
processingBuffersSet.get() : null, memoryParameters, storageParameters, dataServerQueryHandlerFactory diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java index f093cb0b8aac..1a3b99538d1d 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java @@ -84,6 +84,17 @@ public interface FrameContext extends Closeable IndexMerger indexMerger(); + /** + * Acquire processing buffers sized for {@code requestedSlices} concurrent processors. Must be called exactly + * once for stages that use processing buffers, before any call to {@link #processingBuffers()}. Stages that + * don't use processing buffers must not call this method. + */ + void acquireProcessingBuffers(int requestedSlices); + + /** + * Returns the {@link ProcessingBuffers} previously acquired via {@link #acquireProcessingBuffers}. Throws if + * not yet acquired. + */ ProcessingBuffers processingBuffers(); WorkerMemoryParameters memoryParameters(); diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java index 26dea9da9433..5d5b367b3bc4 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java @@ -21,7 +21,6 @@ import org.apache.druid.collections.ResourceHolder; import org.apache.druid.error.DruidException; -import org.apache.druid.msq.kernel.StageDefinition; import java.nio.ByteBuffer; import java.util.Collection; @@ -31,27 +30,44 @@ import java.util.stream.Collectors; /** - * Holds a set of {@link ProcessingBuffers} for a {@link Worker}. Acquired from {@link ProcessingBuffersProvider}. 
+ * Holds a set of {@link Slot}, each of which can produce {@link ProcessingBuffers} for one concurrent stage. + * Acquired from {@link ProcessingBuffersProvider}. + * + * Slots come in two flavors: + *
+ * <ul>
+ * <li>{@link EagerSlot}: holds an already-built {@link ProcessingBuffers}; ignores the requested slice count.
+ * Used by buffer providers that pre-allocate (Peon, Indexer).</li>
+ * <li>Lazy slots (provider-defined): hold a buffer chunk and slice it per stage based on the actual concurrent
+ * processor count, so a stage that runs fewer processors gets larger slices. Used by Dart.</li>
+ * </ul>
*/ public class ProcessingBuffersSet { public static final ProcessingBuffersSet EMPTY = new ProcessingBuffersSet(Collections.emptyList()); - private final BlockingQueue pool; + private final BlockingQueue pool; + + public ProcessingBuffersSet(final Collection slots) + { + this.pool = new ArrayBlockingQueue<>(slots.isEmpty() ? 1 : slots.size()); + this.pool.addAll(slots); + } - public ProcessingBuffersSet(Collection buffers) + /** + * Wrap a collection of pre-built {@link ProcessingBuffers}. + */ + public static ProcessingBuffersSet wrap(final Collection buffers) { - this.pool = new ArrayBlockingQueue<>(buffers.isEmpty() ? 1 : buffers.size()); - this.pool.addAll(buffers); + return new ProcessingBuffersSet(buffers.stream().map(EagerSlot::new).collect(Collectors.toList())); } /** * Equivalent to calling {@link ProcessingBuffers#fromCollection} on each collection in the overall collection, - * then creating an instance. + * then wrapping in eager slots. */ public static > ProcessingBuffersSet fromCollection(final Collection processingBuffers) { - return new ProcessingBuffersSet( + return wrap( processingBuffers.stream() .map(ProcessingBuffers::fromCollection) .collect(Collectors.toList()) @@ -59,32 +75,28 @@ public static > ProcessingBuffersSet fromCollec } /** - * Acquire buffers if a particular stages needs them; otherwise, returns a holder that throws an exception on - * {@link ResourceHolder#get()}. - */ - public ResourceHolder acquireForStage(final StageDefinition stageDef) - { - if (!stageDef.getProcessor().usesProcessingBuffers()) { - return new NilResourceHolder<>(); - } else { - return acquire(); - } - } - - /** - * Acquire buffers unconditionally. In production, it is expected that callers will use - * {@link #acquireForStage(StageDefinition)}. + * Acquire buffers with a specific requested slice count. The actual number of slices may be higher but will + * not be lower. 
*/ - public ResourceHolder acquire() + public ResourceHolder acquire(final int requestedSlices) { - final ProcessingBuffers buffers = pool.poll(); + final Slot slot = pool.poll(); - if (buffers == null) { + if (slot == null) { // Never happens, because the pool acquired from ProcessingBuffersProvider must be big enough for all // concurrent processing buffer needs. (In other words: if this does happen, it's a bug.) throw DruidException.defensive("Processing buffers not available"); } + final ProcessingBuffers buffers; + try { + buffers = slot.acquire(requestedSlices); + } + catch (Throwable t) { + pool.add(slot); + throw t; + } + return new ResourceHolder<>() { @Override @@ -96,26 +108,49 @@ public ProcessingBuffers get() @Override public void close() { - pool.add(buffers); + pool.add(slot); } }; } /** - * Resource holder that throws an exception on {@link #get()}. + * A producer of {@link ProcessingBuffers} from a single concurrent-stage slot in the pool. Implementations + * decide whether the slice count argument to {@link #acquire} is honored (lazy slots) or ignored (eager slots). */ - static class NilResourceHolder implements ResourceHolder + public interface Slot { - @Override - public T get() + /** + * Produce a {@link ProcessingBuffers} suitable for a stage that will run up to {@code requestedSlices} + * concurrent processors. Implementations may choose to ignore the argument when the slot's buffers are + * already laid out (e.g., {@link EagerSlot}). + */ + ProcessingBuffers acquire(int requestedSlices); + } + + /** + * Slot that wraps an already-built {@link ProcessingBuffers}. + */ + public static final class EagerSlot implements Slot + { + private final ProcessingBuffers buffers; + + public EagerSlot(final ProcessingBuffers buffers) { - throw DruidException.defensive("Unexpected call to get()"); + this.buffers = buffers; } @Override - public void close() + public ProcessingBuffers acquire(final int requestedSlices) { - // Do nothing. 
+ if (requestedSlices > buffers.getBouncer().getMaxCount()) { + throw DruidException.defensive( + "requestedSlices[%d] too large, only have[%d] buffers", + requestedSlices, + buffers.getBouncer().getMaxCount() + ); + } + + return buffers; } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index cfae065176a0..13da1393c710 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -22,11 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.druid.client.coordinator.CoordinatorClient; import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; import org.apache.druid.msq.exec.FrameContext; import org.apache.druid.msq.exec.FrameWriterSpec; import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.StageId; @@ -51,11 +53,22 @@ public class IndexerFrameContext implements FrameContext private final SegmentManager segmentManager; @Nullable private final CoordinatorClient coordinatorClient; - private final ResourceHolder processingBuffers; + + /** + * Null if the stage does not use processing buffers. + */ + @Nullable + private final ProcessingBuffersSet processingBuffersSet; private final WorkerMemoryParameters memoryParameters; private final WorkerStorageParameters storageParameters; private final IndexerDataServerQueryHandlerFactory dataServerQueryHandlerFactory; + /** + * Acquired by {@link #acquireProcessingBuffers}. 
+ */ + @Nullable + private ResourceHolder processingBuffers; + public IndexerFrameContext( StageId stageId, IndexerWorkerContext context, @@ -63,7 +76,7 @@ public IndexerFrameContext( IndexIO indexIO, SegmentManager segmentManager, @Nullable CoordinatorClient coordinatorClient, - ResourceHolder processingBuffers, + @Nullable ProcessingBuffersSet processingBuffersSet, IndexerDataServerQueryHandlerFactory dataServerQueryHandlerFactory, WorkerMemoryParameters memoryParameters, WorkerStorageParameters storageParameters @@ -75,7 +88,7 @@ public IndexerFrameContext( this.indexIO = indexIO; this.segmentManager = segmentManager; this.coordinatorClient = coordinatorClient; - this.processingBuffers = processingBuffers; + this.processingBuffersSet = processingBuffersSet; this.memoryParameters = memoryParameters; this.storageParameters = storageParameters; this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; @@ -162,9 +175,24 @@ public IndexMerger indexMerger() return context.toolbox().getIndexMerger(); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + if (processingBuffersSet == null) { + throw DruidException.defensive("Stage[%s] does not use processing buffers", stageId); + } + if (processingBuffers != null) { + throw DruidException.defensive("Processing buffers already acquired"); + } + processingBuffers = processingBuffersSet.acquire(requestedSlices); + } + @Override public ProcessingBuffers processingBuffers() { + if (processingBuffers == null) { + throw DruidException.defensive("Processing buffers not yet acquired"); + } return processingBuffers.get(); } @@ -189,6 +217,8 @@ public FrameWriterSpec frameWriterSpec() @Override public void close() { - processingBuffers.close(); + if (processingBuffers != null) { + processingBuffers.close(); + } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java 
b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index 3a50cdc71f20..2589d0c0dab7 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -304,7 +304,7 @@ public FrameContext frameContext(WorkOrder workOrder) indexIO, segmentManager, coordinatorClient, - processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + workOrder.getStageDefinition().getProcessor().usesProcessingBuffers() ? processingBuffersSet.get() : null, dataServerQueryHandlerFactory, memoryParameters, WorkerStorageParameters.createProductionInstance(injector, workOrder.getOutputChannelMode()) diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java index 240e0b9e75ff..82d344598f24 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java @@ -116,6 +116,10 @@ public ListenableFuture execute(ExecutionContext context) outstandingProcessors = Math.min(totalProcessors, context.threadCount()); } + if (usesProcessingBuffers()) { + frameContext.acquireProcessingBuffers(outstandingProcessors); + } + final Queue frameWriterFactoryQueue = new ArrayDeque<>(outstandingProcessors); final Queue channelQueue = new ArrayDeque<>(outstandingProcessors); final List outputChannels = new ArrayList<>(outstandingProcessors); diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java index 75afddf3fca7..199cffbd1f7a 100644 --- 
a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java @@ -44,7 +44,7 @@ @RunWith(MockitoJUnitRunner.class) public class DartProcessingBuffersProviderTest { - private static final int PROCESSING_THREADS = 2; + private static final int PROCESSING_THREADS = 4; private static final long TIMEOUT_MILLIS = 1000L; private static final int BUFFER_SIZE = 1024; @@ -76,15 +76,107 @@ public void test_acquire_poolSizeZero() result.close(); } + @Test + public void test_acquire_singleSliceUsesFullChunk() + { + // With poolSize=1 and one merge buffer of BUFFER_SIZE, the chunk is BUFFER_SIZE. + // Requesting 1 slice should give a single buffer of full chunk size and a Bouncer of 1. + when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + final ResourceHolder holder = result.get().acquire(1); + try { + final ProcessingBuffers buffers = holder.get(); + Assert.assertEquals(1, buffers.getBouncer().getMaxCount()); + + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(BUFFER_SIZE, sliceHolder.get().capacity()); + sliceHolder.close(); + } + finally { + holder.close(); + } + } + finally { + result.close(); + } + } + + @Test + public void test_acquire_processingThreadsSlices() + { + // Requesting PROCESSING_THREADS slices yields the maximum slicing: each slice is BUFFER_SIZE/PROCESSING_THREADS. 
+ when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + final ResourceHolder holder = result.get().acquire(PROCESSING_THREADS); + try { + final ProcessingBuffers buffers = holder.get(); + Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); + + final List> sliceHolders = new ArrayList<>(); + try { + for (int i = 0; i < PROCESSING_THREADS; i++) { + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(BUFFER_SIZE / PROCESSING_THREADS, sliceHolder.get().capacity()); + sliceHolders.add(sliceHolder); + } + } + finally { + for (final ResourceHolder sh : sliceHolders) { + sh.close(); + } + } + } + finally { + holder.close(); + } + } + finally { + result.close(); + } + } + + @Test + public void test_acquire_resliceAfterRelease() + { + // Acquire with N=2, release, then re-acquire with N=4. The chunk should be re-sliced. + when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + // First acquisition with 2 slices. + final ResourceHolder holder1 = result.get().acquire(2); + Assert.assertEquals(2, holder1.get().getBouncer().getMaxCount()); + Assert.assertEquals(BUFFER_SIZE / 2, holder1.get().getBufferPool().take().get().capacity()); + holder1.close(); + + // Second acquisition with 4 slices — same chunk, different slicing. 
+ final ResourceHolder holder2 = result.get().acquire(4); + Assert.assertEquals(4, holder2.get().getBouncer().getMaxCount()); + Assert.assertEquals(BUFFER_SIZE / 4, holder2.get().getBufferPool().take().get().capacity()); + holder2.close(); + } + finally { + result.close(); + } + } + @Test public void test_acquire_poolSizeTwo() { - // Setup mock to return a buffer when(mockBufferHolder.get()).thenReturn(testBuffer); when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) .thenReturn(List.of(mockBufferHolder)); - // Test successful acquisition final int poolSize = 2; final ResourceHolder result = provider.acquire(poolSize, TIMEOUT_MILLIS); @@ -92,33 +184,28 @@ public void test_acquire_poolSizeTwo() final ProcessingBuffersSet buffersSet = result.get(); Assert.assertNotNull(buffersSet); - // Verify we can acquire buffers from the set + // Each slot's chunk has capacity BUFFER_SIZE/poolSize. Requesting PROCESSING_THREADS slices yields slices + // of size (BUFFER_SIZE/poolSize)/PROCESSING_THREADS. 
for (int i = 0; i < poolSize; i++) { - final ResourceHolder buffersHolder = buffersSet.acquire(); - Assert.assertNotNull(buffersHolder); - - final ProcessingBuffers buffers = buffersHolder.get(); - Assert.assertNotNull(buffers); - Assert.assertNotNull(buffers.getBufferPool()); - Assert.assertNotNull(buffers.getBouncer()); - - // The bouncer should have the correct max count (PROCESSING_THREADS) - Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); - - // Verify that we can get processing threads number of buffers - final List> resourceHolders = new ArrayList<>(); - for (int j = 0; j < PROCESSING_THREADS; j++) { - final ResourceHolder bufferResource = buffers.getBufferPool().take(); - Assert.assertNotNull(bufferResource); - Assert.assertNotNull(bufferResource.get()); - resourceHolders.add(bufferResource); + final ResourceHolder buffersHolder = buffersSet.acquire(PROCESSING_THREADS); + try { + final ProcessingBuffers buffers = buffersHolder.get(); + Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); + + final int expectedSliceSize = BUFFER_SIZE / poolSize / PROCESSING_THREADS; + final List> resourceHolders = new ArrayList<>(); + for (int j = 0; j < PROCESSING_THREADS; j++) { + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(expectedSliceSize, sliceHolder.get().capacity()); + resourceHolders.add(sliceHolder); + } + for (final ResourceHolder resourceHolder : resourceHolders) { + resourceHolder.close(); + } } - - for (final ResourceHolder resourceHolder : resourceHolders) { - resourceHolder.close(); + finally { + buffersHolder.close(); } - - buffersHolder.close(); // Return to pool } result.close(); @@ -127,7 +214,6 @@ public void test_acquire_poolSizeTwo() @Test public void test_acquire_timeout() { - // Setup mock pool to return empty list (as happens during a timeout) when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) .thenReturn(Collections.emptyList()); diff --git 
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java index dd30532e4654..ba78a35f80ea 100644 --- a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java @@ -22,19 +22,13 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.collections.ResourceHolder; import org.apache.druid.error.DruidException; -import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.utils.CloseableUtils; -import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Test; -import org.mockito.Mockito; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Collections; import java.util.List; -import java.util.NoSuchElementException; public class ProcessingBuffersSetTest { @@ -43,7 +37,7 @@ public void test_empty_acquire() { final DruidException e = Assert.assertThrows( DruidException.class, - ProcessingBuffersSet.EMPTY::acquire + () -> ProcessingBuffersSet.EMPTY.acquire(1) ); Assert.assertEquals("Processing buffers not available", e.getMessage()); @@ -66,9 +60,9 @@ public void test_fromCollection() throws IOException final ProcessingBuffersSet buffersSet = ProcessingBuffersSet.fromCollection(bufferLists); // Should be able to acquire all three - final ResourceHolder holder1 = buffersSet.acquire(); - final ResourceHolder holder2 = buffersSet.acquire(); - final ResourceHolder holder3 = buffersSet.acquire(); + final ResourceHolder holder1 = buffersSet.acquire(1); + final ResourceHolder holder2 = buffersSet.acquire(1); + final ResourceHolder holder3 = buffersSet.acquire(1); Assert.assertNotNull(holder1.get()); Assert.assertNotNull(holder2.get()); @@ -86,64 +80,4 @@ public void test_fromCollection() throws IOException 
CloseableUtils.closeAll(holder1, holder2, holder3); } - @Test - public void test_nilResourceHolder() - { - final ProcessingBuffersSet.NilResourceHolder nilHolder = new ProcessingBuffersSet.NilResourceHolder<>(); - - final DruidException e = Assert.assertThrows( - DruidException.class, - nilHolder::get - ); - - Assert.assertEquals("Unexpected call to get()", e.getMessage()); - - nilHolder.close(); // Should do nothing - } - - @Test - public void test_acquireForStage_usesProcessingBuffersFalse() - { - // Create a mock StageDefinition and StageProcessor - final StageDefinition stageDef = Mockito.mock(StageDefinition.class); - final StageProcessor stageProcessor = Mockito.mock(StageProcessor.class); - - // Configure mocks: processor factory does not use processing buffers - Mockito.when(stageDef.getProcessor()).thenReturn(stageProcessor); - Mockito.when(stageProcessor.usesProcessingBuffers()).thenReturn(false); - - // Create a ProcessingBuffersSet - final ProcessingBuffersSet buffersSet = - ProcessingBuffersSet.fromCollection( - Collections.singletonList( - Collections.singletonList(ByteBuffer.allocate(1024)))); - - // Acquire for stage - final ResourceHolder holder = buffersSet.acquireForStage(stageDef); - MatcherAssert.assertThat(holder, CoreMatchers.instanceOf(ProcessingBuffersSet.NilResourceHolder.class)); - } - - @Test - public void test_acquireForStage_usesProcessingBuffersTrue() - { - // Create a mock StageDefinition and StageProcessor - final StageDefinition stageDef = Mockito.mock(StageDefinition.class); - final StageProcessor stageProcessor = Mockito.mock(StageProcessor.class); - - // Configure mocks: processor factory does use processing buffers - Mockito.when(stageDef.getProcessor()).thenReturn(stageProcessor); - Mockito.when(stageProcessor.usesProcessingBuffers()).thenReturn(true); - - // Create a ProcessingBuffersSet - final ProcessingBuffersSet buffersSet = - ProcessingBuffersSet.fromCollection( - Collections.singletonList( - 
Collections.singletonList(ByteBuffer.allocate(1024)))); - - // Acquire for stage - final ResourceHolder holder = buffersSet.acquireForStage(stageDef); - final ProcessingBuffers buffers = holder.get(); - Assert.assertEquals(1024, buffers.getBufferPool().take().get().capacity()); - Assert.assertThrows(NoSuchElementException.class, () -> buffers.getBufferPool().take()); - } } diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index a8773635484c..02ae7c56e1a8 100644 --- a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -317,6 +317,12 @@ public IndexMerger indexMerger() ); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + // No-op: this mock returns a fixed ProcessingBuffers regardless of slice count. + } + @Override public ProcessingBuffers processingBuffers() { From e21ccda7436e5dafae93969074866b557049ebd9 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 7 May 2026 10:56:21 -0700 Subject: [PATCH 2/4] Need to acquire buffers even when there are zero processors. 
--- .../druid/msq/querykit/BaseLeafStageProcessor.java | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java index 82d344598f24..ec3884554e18 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java @@ -100,13 +100,12 @@ public ListenableFuture execute(ExecutionContext context) final ReadableInputQueue baseInputQueue = makeBaseInputQueue(context.workOrder().getInputs(), context); final int totalProcessors = baseInputQueue.remaining(); - if (totalProcessors == 0) { - return stageRunner.run(new ProcessorsAndChannels<>(ProcessorManagers.none(), OutputChannels.none())); - } - final int outstandingProcessors; - if (hasParquet(inputSlices)) { + if (totalProcessors == 0) { + // No processors to run, but still acquire 1 slice so processingBouncer() works in stageRunner.run(). + outstandingProcessors = 1; + } else if (hasParquet(inputSlices)) { // This is a workaround for memory use in ParquetFileReader, which loads up an entire row group into memory as // part of its normal operation. Row groups can be quite large (like, 1GB large) so this is a major source of // unaccounted-for memory use during ingestion and query of external data. 
We are trying to prevent memory @@ -120,6 +119,10 @@ public ListenableFuture execute(ExecutionContext context) frameContext.acquireProcessingBuffers(outstandingProcessors); } + if (totalProcessors == 0) { + return stageRunner.run(new ProcessorsAndChannels<>(ProcessorManagers.none(), OutputChannels.none())); + } + final Queue frameWriterFactoryQueue = new ArrayDeque<>(outstandingProcessors); final Queue channelQueue = new ArrayDeque<>(outstandingProcessors); final List outputChannels = new ArrayList<>(outstandingProcessors); From aa0f05c0f0a3d35bbca8d60025957fecd27bec5e Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 7 May 2026 11:59:42 -0700 Subject: [PATCH 3/4] Add tests. --- .../msq/dart/worker/DartFrameContextTest.java | 144 ++++++++++++++++++ .../msq/indexing/IndexerFrameContextTest.java | 144 ++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java create mode 100644 multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java new file mode 100644 index 000000000000..056115d87a2d --- /dev/null +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class DartFrameContextTest +{ + private static final StageId STAGE_ID = new StageId("query-1", 0); + + private ProcessingBuffersSet buffersSet; + + @Before + public void setUp() + { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + buffersSet = ProcessingBuffersSet.fromCollection(ImmutableList.of(ImmutableList.of(buffer))); + } + + @Test + public void test_acquireProcessingBuffers_nullSet_throws() + { + final DartFrameContext context = makeContext(null); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals( + "Stage[" + STAGE_ID + "] does not use processing buffers", + e.getMessage() + ); + } + + @Test + public void test_acquireProcessingBuffers_alreadyAcquired_throws() + { + final DartFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals("Processing buffers already acquired", e.getMessage()); + + context.close(); + } + + @Test + public void 
test_processingBuffers_notAcquired_throws() + { + final DartFrameContext context = makeContext(buffersSet); + + final DruidException e = Assert.assertThrows( + DruidException.class, + context::processingBuffers + ); + + Assert.assertEquals("Processing buffers not yet acquired", e.getMessage()); + } + + @Test + public void test_processingBuffers_afterAcquire_returnsBuffers() + { + final DartFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final ProcessingBuffers buffers = context.processingBuffers(); + Assert.assertNotNull(buffers); + Assert.assertNotNull(buffers.getBufferPool()); + Assert.assertNotNull(buffers.getBouncer()); + + context.close(); + } + + @Test + public void test_close_withoutAcquire_isNoop() + { + final DartFrameContext context = makeContext(buffersSet); + + // Should not throw. + context.close(); + + // Slot was never acquired, so it should still be available. + buffersSet.acquire(1).close(); + } + + @Test + public void test_close_afterAcquire_releasesSlot() + { + final DartFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + context.close(); + + // Slot should now be back in the pool and re-acquirable. + buffersSet.acquire(1).close(); + } + + private static DartFrameContext makeContext(final ProcessingBuffersSet processingBuffersSet) + { + return new DartFrameContext( + STAGE_ID, + null, + null, + null, + null, + null, + processingBuffersSet, + null, + null, + null + ); + } +} diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java new file mode 100644 index 000000000000..6832198173bf --- /dev/null +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.indexing; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class IndexerFrameContextTest +{ + private static final StageId STAGE_ID = new StageId("query-1", 0); + + private ProcessingBuffersSet buffersSet; + + @Before + public void setUp() + { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + buffersSet = ProcessingBuffersSet.fromCollection(ImmutableList.of(ImmutableList.of(buffer))); + } + + @Test + public void test_acquireProcessingBuffers_nullSet_throws() + { + final IndexerFrameContext context = makeContext(null); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals( + "Stage[" + STAGE_ID + "] does not use processing buffers", + e.getMessage() + ); + } + + @Test + public void test_acquireProcessingBuffers_alreadyAcquired_throws() + { + final IndexerFrameContext context = makeContext(buffersSet); + 
context.acquireProcessingBuffers(1); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals("Processing buffers already acquired", e.getMessage()); + + context.close(); + } + + @Test + public void test_processingBuffers_notAcquired_throws() + { + final IndexerFrameContext context = makeContext(buffersSet); + + final DruidException e = Assert.assertThrows( + DruidException.class, + context::processingBuffers + ); + + Assert.assertEquals("Processing buffers not yet acquired", e.getMessage()); + } + + @Test + public void test_processingBuffers_afterAcquire_returnsBuffers() + { + final IndexerFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final ProcessingBuffers buffers = context.processingBuffers(); + Assert.assertNotNull(buffers); + Assert.assertNotNull(buffers.getBufferPool()); + Assert.assertNotNull(buffers.getBouncer()); + + context.close(); + } + + @Test + public void test_close_withoutAcquire_isNoop() + { + final IndexerFrameContext context = makeContext(buffersSet); + + // Should not throw. + context.close(); + + // Slot was never acquired, so it should still be available. + buffersSet.acquire(1).close(); + } + + @Test + public void test_close_afterAcquire_releasesSlot() + { + final IndexerFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + context.close(); + + // Slot should now be back in the pool and re-acquirable. + buffersSet.acquire(1).close(); + } + + private static IndexerFrameContext makeContext(final ProcessingBuffersSet processingBuffersSet) + { + return new IndexerFrameContext( + STAGE_ID, + null, + null, + null, + null, + null, + processingBuffersSet, + null, + null, + null + ); + } +} From 99cb6a6001714008579c65169d273e044ccdcacd Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Tue, 12 May 2026 10:13:43 -0700 Subject: [PATCH 4/4] Add defensive check. 
--- .../druid/msq/dart/worker/DartProcessingBuffersProvider.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java index 0ee2a276e7ea..d7808a7054a6 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java @@ -109,6 +109,10 @@ public ProcessingBuffers acquire(final int requestedSlices) ); } + if (requestedSlices < 1) { + throw DruidException.defensive("requestedSlices[%d] must be positive", requestedSlices); + } + final int sliceSize = chunk.capacity() / requestedSlices; final BlockingQueue queue = new ArrayBlockingQueue<>(requestedSlices); final ByteBuffer working = chunk.duplicate();