diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java index d17a94754ab7..fca50d51efe0 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java @@ -27,11 +27,11 @@ import org.apache.druid.msq.exec.FrameContext; import org.apache.druid.msq.exec.FrameWriterSpec; import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.rowsandcols.serde.WireTransferableContext; import org.apache.druid.segment.IndexIO; @@ -42,6 +42,7 @@ import org.apache.druid.segment.loading.DataSegmentPusher; import org.apache.druid.server.SegmentManager; +import javax.annotation.Nullable; import java.io.File; /** @@ -52,24 +53,33 @@ public class DartFrameContext implements FrameContext private final StageId stageId; private final FrameWriterSpec frameWriterSpec; private final SegmentWrangler segmentWrangler; - private final GroupingEngine groupingEngine; private final SegmentManager segmentManager; private final CoordinatorClient coordinatorClient; private final WorkerContext workerContext; - private final ResourceHolder processingBuffers; + + /** + * Null if the stage does not use processing buffers. 
+ */ + @Nullable + private final ProcessingBuffersSet processingBuffersSet; private final WorkerMemoryParameters memoryParameters; private final WorkerStorageParameters storageParameters; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; + /** + * Acquired by {@link #acquireProcessingBuffers}. + */ + @Nullable + private ResourceHolder processingBuffers; + public DartFrameContext( final StageId stageId, final WorkerContext workerContext, final FrameWriterSpec frameWriterSpec, final SegmentWrangler segmentWrangler, - final GroupingEngine groupingEngine, final SegmentManager segmentManager, final CoordinatorClient coordinatorClient, - final ResourceHolder processingBuffers, + @Nullable final ProcessingBuffersSet processingBuffersSet, final WorkerMemoryParameters memoryParameters, final WorkerStorageParameters storageParameters, final DataServerQueryHandlerFactory dataServerQueryHandlerFactory @@ -78,11 +88,10 @@ public DartFrameContext( this.stageId = stageId; this.segmentWrangler = segmentWrangler; this.frameWriterSpec = frameWriterSpec; - this.groupingEngine = groupingEngine; this.segmentManager = segmentManager; this.coordinatorClient = coordinatorClient; this.workerContext = workerContext; - this.processingBuffers = processingBuffers; + this.processingBuffersSet = processingBuffersSet; this.memoryParameters = memoryParameters; this.storageParameters = storageParameters; this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; @@ -160,9 +169,24 @@ public IndexMerger indexMerger() throw DruidException.defensive("Ingestion not implemented"); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + if (processingBuffersSet == null) { + throw DruidException.defensive("Stage[%s] does not use processing buffers", stageId); + } + if (processingBuffers != null) { + throw DruidException.defensive("Processing buffers already acquired"); + } + processingBuffers = processingBuffersSet.acquire(requestedSlices); + 
} + @Override public ProcessingBuffers processingBuffers() { + if (processingBuffers == null) { + throw DruidException.defensive("Processing buffers not yet acquired"); + } return processingBuffers.get(); } @@ -193,6 +217,8 @@ public FrameWriterSpec frameWriterSpec() @Override public void close() { - processingBuffers.close(); + if (processingBuffers != null) { + processingBuffers.close(); + } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java index 533f99073a2a..d7808a7054a6 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java @@ -23,6 +23,7 @@ import org.apache.druid.collections.QueueNonBlockingPool; import org.apache.druid.collections.ReferenceCountingResourceHolder; import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.msq.exec.ProcessingBuffers; import org.apache.druid.msq.exec.ProcessingBuffersProvider; @@ -39,7 +40,7 @@ /** * Production implementation of {@link ProcessingBuffersProvider} that uses the merge buffer pool. Each call - * to {@link #acquire(int)} acquires one merge buffer and slices it up. + * to {@link #acquire(int, long)} acquires one merge buffer and slices it up. 
*/ public class DartProcessingBuffersProvider implements ProcessingBuffersProvider { @@ -67,27 +68,59 @@ public ResourceHolder acquire(final int poolSize, final lo final ReferenceCountingResourceHolder bufferHolder = batch.get(0); try { final ByteBuffer buffer = bufferHolder.get().duplicate(); - final int sliceSize = buffer.capacity() / poolSize / processingThreads; - final List pool = new ArrayList<>(poolSize); + final int chunkSize = buffer.capacity() / poolSize; + final List slots = new ArrayList<>(poolSize); for (int i = 0; i < poolSize; i++) { - final BlockingQueue queue = new ArrayBlockingQueue<>(processingThreads); - for (int j = 0; j < processingThreads; j++) { - final int sliceNum = i * processingThreads + j; - buffer.position(sliceSize * sliceNum).limit(sliceSize * (sliceNum + 1)); - queue.add(buffer.slice()); - } - final ProcessingBuffers buffers = new ProcessingBuffers( - new QueueNonBlockingPool<>(queue), - new Bouncer(processingThreads) - ); - pool.add(buffers); + buffer.position(chunkSize * i).limit(chunkSize * (i + 1)); + slots.add(new LazySlot(buffer.slice(), processingThreads)); } - return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(pool), bufferHolder); + return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(slots), bufferHolder); } catch (Throwable e) { throw CloseableUtils.closeAndWrapInCatch(e, bufferHolder); } } + + /** + * Lazy slot that holds one chunk of the shared merge buffer and slices it on demand to match the stage's + * actual concurrent-processor count. 
+ */ + static final class LazySlot implements ProcessingBuffersSet.Slot + { + private final ByteBuffer chunk; + private final int maxSlices; + + LazySlot(final ByteBuffer chunk, final int maxSlices) + { + this.chunk = chunk; + this.maxSlices = maxSlices; + } + + @Override + public ProcessingBuffers acquire(final int requestedSlices) + { + if (requestedSlices > maxSlices) { + throw DruidException.defensive( + "requestedSlices[%d] too large for maxSlices[%d]", + requestedSlices, + maxSlices + ); + } + + if (requestedSlices < 1) { + throw DruidException.defensive("requestedSlices[%d] must be positive", requestedSlices); + } + + final int sliceSize = chunk.capacity() / requestedSlices; + final BlockingQueue queue = new ArrayBlockingQueue<>(requestedSlices); + final ByteBuffer working = chunk.duplicate(); + for (int j = 0; j < requestedSlices; j++) { + working.position(sliceSize * j).limit(sliceSize * (j + 1)); + queue.add(working.slice()); + } + return new ProcessingBuffers(new QueueNonBlockingPool<>(queue), new Bouncer(requestedSlices)); + } + } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java index bde9f8968a0d..768801261d1b 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java @@ -250,10 +250,9 @@ public FrameContext frameContext(WorkOrder workOrder) this, FrameWriterSpec.fromContext(workOrder.getWorkerContext()), segmentWrangler, - groupingEngine, segmentManager, coordinatorClient, - processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + workOrder.getStageDefinition().getProcessor().usesProcessingBuffers() ? 
processingBuffersSet.get() : null, memoryParameters, storageParameters, dataServerQueryHandlerFactory diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java index f093cb0b8aac..1a3b99538d1d 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/FrameContext.java @@ -84,6 +84,17 @@ public interface FrameContext extends Closeable IndexMerger indexMerger(); + /** + * Acquire processing buffers sized for {@code requestedSlices} concurrent processors. Must be called exactly + * once for stages that use processing buffers, before any call to {@link #processingBuffers()}. Stages that + * don't use processing buffers must not call this method. + */ + void acquireProcessingBuffers(int requestedSlices); + + /** + * Returns the {@link ProcessingBuffers} previously acquired via {@link #acquireProcessingBuffers}. Throws if + * not yet acquired. + */ ProcessingBuffers processingBuffers(); WorkerMemoryParameters memoryParameters(); diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java index 26dea9da9433..5d5b367b3bc4 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ProcessingBuffersSet.java @@ -21,7 +21,6 @@ import org.apache.druid.collections.ResourceHolder; import org.apache.druid.error.DruidException; -import org.apache.druid.msq.kernel.StageDefinition; import java.nio.ByteBuffer; import java.util.Collection; @@ -31,27 +30,44 @@ import java.util.stream.Collectors; /** - * Holds a set of {@link ProcessingBuffers} for a {@link Worker}. Acquired from {@link ProcessingBuffersProvider}. 
+ * Holds a set of {@link Slot}, each of which can produce {@link ProcessingBuffers} for one concurrent stage. + * Acquired from {@link ProcessingBuffersProvider}. + * + * Slots come in two flavors: + *
 + * <ul>
 + * <li>{@link EagerSlot}: holds an already-built {@link ProcessingBuffers}; ignores the requested slice count.
 + * Used by buffer providers that pre-allocate (Peon, Indexer).</li>
 + * <li>Lazy slots (provider-defined): hold a buffer chunk and slice it per stage based on the actual concurrent
 + * processor count, so a stage that runs fewer processors gets larger slices. Used by Dart.</li>
 + * </ul>
*/ public class ProcessingBuffersSet { public static final ProcessingBuffersSet EMPTY = new ProcessingBuffersSet(Collections.emptyList()); - private final BlockingQueue pool; + private final BlockingQueue pool; + + public ProcessingBuffersSet(final Collection slots) + { + this.pool = new ArrayBlockingQueue<>(slots.isEmpty() ? 1 : slots.size()); + this.pool.addAll(slots); + } - public ProcessingBuffersSet(Collection buffers) + /** + * Wrap a collection of pre-built {@link ProcessingBuffers}. + */ + public static ProcessingBuffersSet wrap(final Collection buffers) { - this.pool = new ArrayBlockingQueue<>(buffers.isEmpty() ? 1 : buffers.size()); - this.pool.addAll(buffers); + return new ProcessingBuffersSet(buffers.stream().map(EagerSlot::new).collect(Collectors.toList())); } /** * Equivalent to calling {@link ProcessingBuffers#fromCollection} on each collection in the overall collection, - * then creating an instance. + * then wrapping in eager slots. */ public static > ProcessingBuffersSet fromCollection(final Collection processingBuffers) { - return new ProcessingBuffersSet( + return wrap( processingBuffers.stream() .map(ProcessingBuffers::fromCollection) .collect(Collectors.toList()) @@ -59,32 +75,28 @@ public static > ProcessingBuffersSet fromCollec } /** - * Acquire buffers if a particular stages needs them; otherwise, returns a holder that throws an exception on - * {@link ResourceHolder#get()}. - */ - public ResourceHolder acquireForStage(final StageDefinition stageDef) - { - if (!stageDef.getProcessor().usesProcessingBuffers()) { - return new NilResourceHolder<>(); - } else { - return acquire(); - } - } - - /** - * Acquire buffers unconditionally. In production, it is expected that callers will use - * {@link #acquireForStage(StageDefinition)}. + * Acquire buffers with a specific requested slice count. The actual number of slices may be higher but will + * not be lower. 
*/ - public ResourceHolder acquire() + public ResourceHolder acquire(final int requestedSlices) { - final ProcessingBuffers buffers = pool.poll(); + final Slot slot = pool.poll(); - if (buffers == null) { + if (slot == null) { // Never happens, because the pool acquired from ProcessingBuffersProvider must be big enough for all // concurrent processing buffer needs. (In other words: if this does happen, it's a bug.) throw DruidException.defensive("Processing buffers not available"); } + final ProcessingBuffers buffers; + try { + buffers = slot.acquire(requestedSlices); + } + catch (Throwable t) { + pool.add(slot); + throw t; + } + return new ResourceHolder<>() { @Override @@ -96,26 +108,49 @@ public ProcessingBuffers get() @Override public void close() { - pool.add(buffers); + pool.add(slot); } }; } /** - * Resource holder that throws an exception on {@link #get()}. + * A producer of {@link ProcessingBuffers} from a single concurrent-stage slot in the pool. Implementations + * decide whether the slice count argument to {@link #acquire} is honored (lazy slots) or ignored (eager slots). */ - static class NilResourceHolder implements ResourceHolder + public interface Slot { - @Override - public T get() + /** + * Produce a {@link ProcessingBuffers} suitable for a stage that will run up to {@code requestedSlices} + * concurrent processors. Implementations may choose to ignore the argument when the slot's buffers are + * already laid out (e.g., {@link EagerSlot}). + */ + ProcessingBuffers acquire(int requestedSlices); + } + + /** + * Slot that wraps an already-built {@link ProcessingBuffers}. + */ + public static final class EagerSlot implements Slot + { + private final ProcessingBuffers buffers; + + public EagerSlot(final ProcessingBuffers buffers) { - throw DruidException.defensive("Unexpected call to get()"); + this.buffers = buffers; } @Override - public void close() + public ProcessingBuffers acquire(final int requestedSlices) { - // Do nothing. 
+ if (requestedSlices > buffers.getBouncer().getMaxCount()) { + throw DruidException.defensive( + "requestedSlices[%d] too large, only have[%d] buffers", + requestedSlices, + buffers.getBouncer().getMaxCount() + ); + } + + return buffers; } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index cfae065176a0..13da1393c710 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -22,11 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.druid.client.coordinator.CoordinatorClient; import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; import org.apache.druid.msq.exec.FrameContext; import org.apache.druid.msq.exec.FrameWriterSpec; import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.StageId; @@ -51,11 +53,22 @@ public class IndexerFrameContext implements FrameContext private final SegmentManager segmentManager; @Nullable private final CoordinatorClient coordinatorClient; - private final ResourceHolder processingBuffers; + + /** + * Null if the stage does not use processing buffers. + */ + @Nullable + private final ProcessingBuffersSet processingBuffersSet; private final WorkerMemoryParameters memoryParameters; private final WorkerStorageParameters storageParameters; private final IndexerDataServerQueryHandlerFactory dataServerQueryHandlerFactory; + /** + * Acquired by {@link #acquireProcessingBuffers}. 
+ */ + @Nullable + private ResourceHolder processingBuffers; + public IndexerFrameContext( StageId stageId, IndexerWorkerContext context, @@ -63,7 +76,7 @@ public IndexerFrameContext( IndexIO indexIO, SegmentManager segmentManager, @Nullable CoordinatorClient coordinatorClient, - ResourceHolder processingBuffers, + @Nullable ProcessingBuffersSet processingBuffersSet, IndexerDataServerQueryHandlerFactory dataServerQueryHandlerFactory, WorkerMemoryParameters memoryParameters, WorkerStorageParameters storageParameters @@ -75,7 +88,7 @@ public IndexerFrameContext( this.indexIO = indexIO; this.segmentManager = segmentManager; this.coordinatorClient = coordinatorClient; - this.processingBuffers = processingBuffers; + this.processingBuffersSet = processingBuffersSet; this.memoryParameters = memoryParameters; this.storageParameters = storageParameters; this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; @@ -162,9 +175,24 @@ public IndexMerger indexMerger() return context.toolbox().getIndexMerger(); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + if (processingBuffersSet == null) { + throw DruidException.defensive("Stage[%s] does not use processing buffers", stageId); + } + if (processingBuffers != null) { + throw DruidException.defensive("Processing buffers already acquired"); + } + processingBuffers = processingBuffersSet.acquire(requestedSlices); + } + @Override public ProcessingBuffers processingBuffers() { + if (processingBuffers == null) { + throw DruidException.defensive("Processing buffers not yet acquired"); + } return processingBuffers.get(); } @@ -189,6 +217,8 @@ public FrameWriterSpec frameWriterSpec() @Override public void close() { - processingBuffers.close(); + if (processingBuffers != null) { + processingBuffers.close(); + } } } diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java 
b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index 3a50cdc71f20..2589d0c0dab7 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -304,7 +304,7 @@ public FrameContext frameContext(WorkOrder workOrder) indexIO, segmentManager, coordinatorClient, - processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + workOrder.getStageDefinition().getProcessor().usesProcessingBuffers() ? processingBuffersSet.get() : null, dataServerQueryHandlerFactory, memoryParameters, WorkerStorageParameters.createProductionInstance(injector, workOrder.getOutputChannelMode()) diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java index 240e0b9e75ff..ec3884554e18 100644 --- a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java +++ b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java @@ -100,13 +100,12 @@ public ListenableFuture execute(ExecutionContext context) final ReadableInputQueue baseInputQueue = makeBaseInputQueue(context.workOrder().getInputs(), context); final int totalProcessors = baseInputQueue.remaining(); - if (totalProcessors == 0) { - return stageRunner.run(new ProcessorsAndChannels<>(ProcessorManagers.none(), OutputChannels.none())); - } - final int outstandingProcessors; - if (hasParquet(inputSlices)) { + if (totalProcessors == 0) { + // No processors to run, but still acquire 1 slice so processingBouncer() works in stageRunner.run(). + outstandingProcessors = 1; + } else if (hasParquet(inputSlices)) { // This is a workaround for memory use in ParquetFileReader, which loads up an entire row group into memory as // part of its normal operation. 
Row groups can be quite large (like, 1GB large) so this is a major source of // unaccounted-for memory use during ingestion and query of external data. We are trying to prevent memory @@ -116,6 +115,14 @@ public ListenableFuture execute(ExecutionContext context) outstandingProcessors = Math.min(totalProcessors, context.threadCount()); } + if (usesProcessingBuffers()) { + frameContext.acquireProcessingBuffers(outstandingProcessors); + } + + if (totalProcessors == 0) { + return stageRunner.run(new ProcessorsAndChannels<>(ProcessorManagers.none(), OutputChannels.none())); + } + final Queue frameWriterFactoryQueue = new ArrayDeque<>(outstandingProcessors); final Queue channelQueue = new ArrayDeque<>(outstandingProcessors); final List outputChannels = new ArrayList<>(outstandingProcessors); diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java new file mode 100644 index 000000000000..056115d87a2d --- /dev/null +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartFrameContextTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class DartFrameContextTest +{ + private static final StageId STAGE_ID = new StageId("query-1", 0); + + private ProcessingBuffersSet buffersSet; + + @Before + public void setUp() + { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + buffersSet = ProcessingBuffersSet.fromCollection(ImmutableList.of(ImmutableList.of(buffer))); + } + + @Test + public void test_acquireProcessingBuffers_nullSet_throws() + { + final DartFrameContext context = makeContext(null); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals( + "Stage[" + STAGE_ID + "] does not use processing buffers", + e.getMessage() + ); + } + + @Test + public void test_acquireProcessingBuffers_alreadyAcquired_throws() + { + final DartFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals("Processing buffers already acquired", e.getMessage()); + + context.close(); + } + + @Test + public void test_processingBuffers_notAcquired_throws() + { + final DartFrameContext context = makeContext(buffersSet); + + final DruidException e = Assert.assertThrows( + DruidException.class, + context::processingBuffers + ); + + Assert.assertEquals("Processing buffers not yet acquired", e.getMessage()); + } + + @Test + public void test_processingBuffers_afterAcquire_returnsBuffers() + { + final DartFrameContext context = makeContext(buffersSet); + 
context.acquireProcessingBuffers(1); + + final ProcessingBuffers buffers = context.processingBuffers(); + Assert.assertNotNull(buffers); + Assert.assertNotNull(buffers.getBufferPool()); + Assert.assertNotNull(buffers.getBouncer()); + + context.close(); + } + + @Test + public void test_close_withoutAcquire_isNoop() + { + final DartFrameContext context = makeContext(buffersSet); + + // Should not throw. + context.close(); + + // Slot was never acquired, so it should still be available. + buffersSet.acquire(1).close(); + } + + @Test + public void test_close_afterAcquire_releasesSlot() + { + final DartFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + context.close(); + + // Slot should now be back in the pool and re-acquirable. + buffersSet.acquire(1).close(); + } + + private static DartFrameContext makeContext(final ProcessingBuffersSet processingBuffersSet) + { + return new DartFrameContext( + STAGE_ID, + null, + null, + null, + null, + null, + processingBuffersSet, + null, + null, + null + ); + } +} diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java index 75afddf3fca7..199cffbd1f7a 100644 --- a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProviderTest.java @@ -44,7 +44,7 @@ @RunWith(MockitoJUnitRunner.class) public class DartProcessingBuffersProviderTest { - private static final int PROCESSING_THREADS = 2; + private static final int PROCESSING_THREADS = 4; private static final long TIMEOUT_MILLIS = 1000L; private static final int BUFFER_SIZE = 1024; @@ -76,15 +76,107 @@ public void test_acquire_poolSizeZero() result.close(); } + @Test + public void test_acquire_singleSliceUsesFullChunk() + { + // With 
poolSize=1 and one merge buffer of BUFFER_SIZE, the chunk is BUFFER_SIZE. + // Requesting 1 slice should give a single buffer of full chunk size and a Bouncer of 1. + when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + final ResourceHolder holder = result.get().acquire(1); + try { + final ProcessingBuffers buffers = holder.get(); + Assert.assertEquals(1, buffers.getBouncer().getMaxCount()); + + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(BUFFER_SIZE, sliceHolder.get().capacity()); + sliceHolder.close(); + } + finally { + holder.close(); + } + } + finally { + result.close(); + } + } + + @Test + public void test_acquire_processingThreadsSlices() + { + // Requesting PROCESSING_THREADS slices yields the maximum slicing: each slice is BUFFER_SIZE/PROCESSING_THREADS. 
+ when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + final ResourceHolder holder = result.get().acquire(PROCESSING_THREADS); + try { + final ProcessingBuffers buffers = holder.get(); + Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); + + final List> sliceHolders = new ArrayList<>(); + try { + for (int i = 0; i < PROCESSING_THREADS; i++) { + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(BUFFER_SIZE / PROCESSING_THREADS, sliceHolder.get().capacity()); + sliceHolders.add(sliceHolder); + } + } + finally { + for (final ResourceHolder sh : sliceHolders) { + sh.close(); + } + } + } + finally { + holder.close(); + } + } + finally { + result.close(); + } + } + + @Test + public void test_acquire_resliceAfterRelease() + { + // Acquire with N=2, release, then re-acquire with N=4. The chunk should be re-sliced. + when(mockBufferHolder.get()).thenReturn(testBuffer); + when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) + .thenReturn(List.of(mockBufferHolder)); + + final ResourceHolder result = provider.acquire(1, TIMEOUT_MILLIS); + try { + // First acquisition with 2 slices. + final ResourceHolder holder1 = result.get().acquire(2); + Assert.assertEquals(2, holder1.get().getBouncer().getMaxCount()); + Assert.assertEquals(BUFFER_SIZE / 2, holder1.get().getBufferPool().take().get().capacity()); + holder1.close(); + + // Second acquisition with 4 slices — same chunk, different slicing. 
+ final ResourceHolder holder2 = result.get().acquire(4); + Assert.assertEquals(4, holder2.get().getBouncer().getMaxCount()); + Assert.assertEquals(BUFFER_SIZE / 4, holder2.get().getBufferPool().take().get().capacity()); + holder2.close(); + } + finally { + result.close(); + } + } + @Test public void test_acquire_poolSizeTwo() { - // Setup mock to return a buffer when(mockBufferHolder.get()).thenReturn(testBuffer); when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) .thenReturn(List.of(mockBufferHolder)); - // Test successful acquisition final int poolSize = 2; final ResourceHolder result = provider.acquire(poolSize, TIMEOUT_MILLIS); @@ -92,33 +184,28 @@ public void test_acquire_poolSizeTwo() final ProcessingBuffersSet buffersSet = result.get(); Assert.assertNotNull(buffersSet); - // Verify we can acquire buffers from the set + // Each slot's chunk has capacity BUFFER_SIZE/poolSize. Requesting PROCESSING_THREADS slices yields slices + // of size (BUFFER_SIZE/poolSize)/PROCESSING_THREADS. 
for (int i = 0; i < poolSize; i++) { - final ResourceHolder buffersHolder = buffersSet.acquire(); - Assert.assertNotNull(buffersHolder); - - final ProcessingBuffers buffers = buffersHolder.get(); - Assert.assertNotNull(buffers); - Assert.assertNotNull(buffers.getBufferPool()); - Assert.assertNotNull(buffers.getBouncer()); - - // The bouncer should have the correct max count (PROCESSING_THREADS) - Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); - - // Verify that we can get processing threads number of buffers - final List> resourceHolders = new ArrayList<>(); - for (int j = 0; j < PROCESSING_THREADS; j++) { - final ResourceHolder bufferResource = buffers.getBufferPool().take(); - Assert.assertNotNull(bufferResource); - Assert.assertNotNull(bufferResource.get()); - resourceHolders.add(bufferResource); + final ResourceHolder buffersHolder = buffersSet.acquire(PROCESSING_THREADS); + try { + final ProcessingBuffers buffers = buffersHolder.get(); + Assert.assertEquals(PROCESSING_THREADS, buffers.getBouncer().getMaxCount()); + + final int expectedSliceSize = BUFFER_SIZE / poolSize / PROCESSING_THREADS; + final List> resourceHolders = new ArrayList<>(); + for (int j = 0; j < PROCESSING_THREADS; j++) { + final ResourceHolder sliceHolder = buffers.getBufferPool().take(); + Assert.assertEquals(expectedSliceSize, sliceHolder.get().capacity()); + resourceHolders.add(sliceHolder); + } + for (final ResourceHolder resourceHolder : resourceHolders) { + resourceHolder.close(); + } } - - for (final ResourceHolder resourceHolder : resourceHolders) { - resourceHolder.close(); + finally { + buffersHolder.close(); } - - buffersHolder.close(); // Return to pool } result.close(); @@ -127,7 +214,6 @@ public void test_acquire_poolSizeTwo() @Test public void test_acquire_timeout() { - // Setup mock pool to return empty list (as happens during a timeout) when(mockMergeBufferPool.takeBatch(eq(1), eq(TIMEOUT_MILLIS))) .thenReturn(Collections.emptyList()); diff --git 
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java index dd30532e4654..ba78a35f80ea 100644 --- a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ProcessingBuffersSetTest.java @@ -22,19 +22,13 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.collections.ResourceHolder; import org.apache.druid.error.DruidException; -import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.utils.CloseableUtils; -import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Test; -import org.mockito.Mockito; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Collections; import java.util.List; -import java.util.NoSuchElementException; public class ProcessingBuffersSetTest { @@ -43,7 +37,7 @@ public void test_empty_acquire() { final DruidException e = Assert.assertThrows( DruidException.class, - ProcessingBuffersSet.EMPTY::acquire + () -> ProcessingBuffersSet.EMPTY.acquire(1) ); Assert.assertEquals("Processing buffers not available", e.getMessage()); @@ -66,9 +60,9 @@ public void test_fromCollection() throws IOException final ProcessingBuffersSet buffersSet = ProcessingBuffersSet.fromCollection(bufferLists); // Should be able to acquire all three - final ResourceHolder holder1 = buffersSet.acquire(); - final ResourceHolder holder2 = buffersSet.acquire(); - final ResourceHolder holder3 = buffersSet.acquire(); + final ResourceHolder holder1 = buffersSet.acquire(1); + final ResourceHolder holder2 = buffersSet.acquire(1); + final ResourceHolder holder3 = buffersSet.acquire(1); Assert.assertNotNull(holder1.get()); Assert.assertNotNull(holder2.get()); @@ -86,64 +80,4 @@ public void test_fromCollection() throws IOException 
CloseableUtils.closeAll(holder1, holder2, holder3); } - @Test - public void test_nilResourceHolder() - { - final ProcessingBuffersSet.NilResourceHolder nilHolder = new ProcessingBuffersSet.NilResourceHolder<>(); - - final DruidException e = Assert.assertThrows( - DruidException.class, - nilHolder::get - ); - - Assert.assertEquals("Unexpected call to get()", e.getMessage()); - - nilHolder.close(); // Should do nothing - } - - @Test - public void test_acquireForStage_usesProcessingBuffersFalse() - { - // Create a mock StageDefinition and StageProcessor - final StageDefinition stageDef = Mockito.mock(StageDefinition.class); - final StageProcessor stageProcessor = Mockito.mock(StageProcessor.class); - - // Configure mocks: processor factory does not use processing buffers - Mockito.when(stageDef.getProcessor()).thenReturn(stageProcessor); - Mockito.when(stageProcessor.usesProcessingBuffers()).thenReturn(false); - - // Create a ProcessingBuffersSet - final ProcessingBuffersSet buffersSet = - ProcessingBuffersSet.fromCollection( - Collections.singletonList( - Collections.singletonList(ByteBuffer.allocate(1024)))); - - // Acquire for stage - final ResourceHolder holder = buffersSet.acquireForStage(stageDef); - MatcherAssert.assertThat(holder, CoreMatchers.instanceOf(ProcessingBuffersSet.NilResourceHolder.class)); - } - - @Test - public void test_acquireForStage_usesProcessingBuffersTrue() - { - // Create a mock StageDefinition and StageProcessor - final StageDefinition stageDef = Mockito.mock(StageDefinition.class); - final StageProcessor stageProcessor = Mockito.mock(StageProcessor.class); - - // Configure mocks: processor factory does use processing buffers - Mockito.when(stageDef.getProcessor()).thenReturn(stageProcessor); - Mockito.when(stageProcessor.usesProcessingBuffers()).thenReturn(true); - - // Create a ProcessingBuffersSet - final ProcessingBuffersSet buffersSet = - ProcessingBuffersSet.fromCollection( - Collections.singletonList( - 
Collections.singletonList(ByteBuffer.allocate(1024)))); - - // Acquire for stage - final ResourceHolder holder = buffersSet.acquireForStage(stageDef); - final ProcessingBuffers buffers = holder.get(); - Assert.assertEquals(1024, buffers.getBufferPool().take().get().capacity()); - Assert.assertThrows(NoSuchElementException.class, () -> buffers.getBufferPool().take()); - } } diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java new file mode 100644 index 000000000000..6832198173bf --- /dev/null +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerFrameContextTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.indexing; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class IndexerFrameContextTest +{ + private static final StageId STAGE_ID = new StageId("query-1", 0); + + private ProcessingBuffersSet buffersSet; + + @Before + public void setUp() + { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + buffersSet = ProcessingBuffersSet.fromCollection(ImmutableList.of(ImmutableList.of(buffer))); + } + + @Test + public void test_acquireProcessingBuffers_nullSet_throws() + { + final IndexerFrameContext context = makeContext(null); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals( + "Stage[" + STAGE_ID + "] does not use processing buffers", + e.getMessage() + ); + } + + @Test + public void test_acquireProcessingBuffers_alreadyAcquired_throws() + { + final IndexerFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> context.acquireProcessingBuffers(1) + ); + + Assert.assertEquals("Processing buffers already acquired", e.getMessage()); + + context.close(); + } + + @Test + public void test_processingBuffers_notAcquired_throws() + { + final IndexerFrameContext context = makeContext(buffersSet); + + final DruidException e = Assert.assertThrows( + DruidException.class, + context::processingBuffers + ); + + Assert.assertEquals("Processing buffers not yet acquired", e.getMessage()); + } + + @Test + public void test_processingBuffers_afterAcquire_returnsBuffers() + { + final IndexerFrameContext context = 
makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + final ProcessingBuffers buffers = context.processingBuffers(); + Assert.assertNotNull(buffers); + Assert.assertNotNull(buffers.getBufferPool()); + Assert.assertNotNull(buffers.getBouncer()); + + context.close(); + } + + @Test + public void test_close_withoutAcquire_isNoop() + { + final IndexerFrameContext context = makeContext(buffersSet); + + // Should not throw. + context.close(); + + // Slot was never acquired, so it should still be available. + buffersSet.acquire(1).close(); + } + + @Test + public void test_close_afterAcquire_releasesSlot() + { + final IndexerFrameContext context = makeContext(buffersSet); + context.acquireProcessingBuffers(1); + + context.close(); + + // Slot should now be back in the pool and re-acquirable. + buffersSet.acquire(1).close(); + } + + private static IndexerFrameContext makeContext(final ProcessingBuffersSet processingBuffersSet) + { + return new IndexerFrameContext( + STAGE_ID, + null, + null, + null, + null, + null, + processingBuffersSet, + null, + null, + null + ); + } +} diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index a8773635484c..02ae7c56e1a8 100644 --- a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -317,6 +317,12 @@ public IndexMerger indexMerger() ); } + @Override + public void acquireProcessingBuffers(final int requestedSlices) + { + // No-op: this mock returns a fixed ProcessingBuffers regardless of slice count. + } + @Override public ProcessingBuffers processingBuffers() {