Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public class ParallelMergeCombiningSequence<T> extends YieldingSequenceBase<T>
public static final int DEFAULT_TASK_TARGET_RUN_TIME_MILLIS = 100;
public static final int DEFAULT_TASK_INITIAL_YIELD_NUM_ROWS = 16384;
public static final int DEFAULT_TASK_SMALL_BATCH_NUM_ROWS = 4096;
static final int DEFAULT_MAX_FAN_IN = 4;

private final ForkJoinPool workerPool;
private final List<Sequence<T>> inputSequences;
Expand Down Expand Up @@ -277,22 +278,24 @@ public void cleanup(Iterator<T> iterFromMake)

/**
* This {@link RecursiveAction} is the initial task of the parallel merge-combine process. Capacity and input sequence
* count permitting, it will partition the input set of {@link Sequence} to do 2 layer parallel merge.
* count permitting, it will partition the input set of {@link Sequence} into a multi-level merge tree with bounded
* fan-in ({@link #DEFAULT_MAX_FAN_IN}) at each level.
*
* For the first layer, the partitions of input sequences are each wrapped in {@link YielderBatchedResultsCursor}, and
* At the leaf level, partitions of input sequences are each wrapped in {@link YielderBatchedResultsCursor}, and
* for each partition a {@link PrepareMergeCombineInputsAction} will be executed to wait for each of the yielders to
* yield {@link ResultBatch}. After the cursors all have an initial set of results, the
* {@link PrepareMergeCombineInputsAction} will execute a {@link MergeCombineAction}
* to perform the actual work of merging sequences and combining results. The merged and combined output of each
* partition will itself be put into {@link ResultBatch} and pushed to a {@link BlockingQueue} with a
* {@link ForkJoinPool} {@link QueuePusher}.
*
* The second layer will execute a single {@link PrepareMergeCombineInputsAction} to wait for the {@link ResultBatch}
* from each partition to be available in their 'output' {@link BlockingQueue} which each is wrapped in
* {@link BlockingQueueuBatchedResultsCursor}. Like the first layer, after the {@link PrepareMergeCombineInputsAction}
* is complete and some {@link ResultBatch} are ready to merge from each partition, it will execute a
* {@link MergeCombineAction} do a final merge combine of all the parallel computed results, again pushing
* {@link ResultBatch} into a {@link BlockingQueue} with a {@link QueuePusher}.
* If the number of leaf outputs exceeds {@link #DEFAULT_MAX_FAN_IN}, intermediate merge levels are created. Each
* intermediate level groups the outputs from the previous level into groups of {@link #DEFAULT_MAX_FAN_IN} and
* merges them via {@link BlockingQueueuBatchedResultsCursor}. This continues until the number of outputs is within
* the fan-in bound, at which point a final merge produces the output.
*
* This multi-level approach ensures that no single merge task has excessive fan-in, distributing expensive
* combine operations (e.g., HLL sketch unions) across more concurrent workers.
*/
@SuppressWarnings("serial")
private static class MergeCombinePartitioningAction<T> extends RecursiveAction
Expand Down Expand Up @@ -380,8 +383,8 @@ protected void compute()
);
getPool().execute(blockForInputsAction);
} else {
// 2 layer parallel merge done in fjp
LOG.debug("Spawning %s parallel merge-combine tasks for %s sequences", parallelTaskCount, sequences.size());
// multi-level parallel merge done in fjp
LOG.debug("Spawning %s leaf merge-combine tasks for %s sequences", parallelTaskCount, sequences.size());
spawnParallelTasks(parallelTaskCount);
}
}
Expand All @@ -395,27 +398,27 @@ protected void compute()
}
}

private void spawnParallelTasks(int parallelMergeTasks)
private void spawnParallelTasks(int numLeafTasks)
{
List<RecursiveAction> tasks = new ArrayList<>(parallelMergeTasks);
List<MergeCombineActionMetricsAccumulator> taskMetrics = new ArrayList<>(parallelMergeTasks);

List<BlockingQueue<ResultBatch<T>>> intermediaryOutputs = new ArrayList<>(parallelMergeTasks);
List<MergeCombineActionMetricsAccumulator> leafTaskMetrics = new ArrayList<>();
List<MergeCombineActionMetricsAccumulator> intermediateTaskMetrics = new ArrayList<>();

// Step 1: Create leaf-level merge tasks from input sequences
List<BlockingQueue<ResultBatch<T>>> currentLevelQueues = new ArrayList<>();
List<? extends List<Sequence<T>>> partitions =
Lists.partition(sequences, sequences.size() / parallelMergeTasks);
Lists.partition(sequences, sequences.size() / numLeafTasks);

for (List<Sequence<T>> partition : partitions) {
BlockingQueue<ResultBatch<T>> outputQueue = new ArrayBlockingQueue<>(queueSize);
intermediaryOutputs.add(outputQueue);
currentLevelQueues.add(outputQueue);
QueuePusher<T> pusher = new QueuePusher<>(outputQueue, cancellationGizmo, hasTimeout, timeoutAt);

List<BatchedResultsCursor<T>> partitionCursors = new ArrayList<>(sequences.size());
List<BatchedResultsCursor<T>> partitionCursors = new ArrayList<>(partition.size());
for (Sequence<T> s : partition) {
partitionCursors.add(new YielderBatchedResultsCursor<>(new SequenceBatcher<>(s, batchSize), orderingFn));
}
MergeCombineActionMetricsAccumulator partitionAccumulator = new MergeCombineActionMetricsAccumulator();
PrepareMergeCombineInputsAction<T> blockForInputsAction = new PrepareMergeCombineInputsAction<>(
getPool().execute(new PrepareMergeCombineInputsAction<>(
partitionCursors,
pusher,
orderingFn,
Expand All @@ -425,29 +428,64 @@ private void spawnParallelTasks(int parallelMergeTasks)
targetTimeNanos,
partitionAccumulator,
cancellationGizmo
);
tasks.add(blockForInputsAction);
taskMetrics.add(partitionAccumulator);
));
leafTaskMetrics.add(partitionAccumulator);
}

metricsAccumulator.setPartitions(taskMetrics);
// Step 2: Build intermediate merge levels with bounded fan-in until within MAX_FAN_IN
int treeLevel = 1;
while (currentLevelQueues.size() > DEFAULT_MAX_FAN_IN) {
List<BlockingQueue<ResultBatch<T>>> nextLevelQueues = new ArrayList<>();
List<? extends List<BlockingQueue<ResultBatch<T>>>> groups =
Lists.partition(currentLevelQueues, DEFAULT_MAX_FAN_IN);

for (List<BlockingQueue<ResultBatch<T>>> group : groups) {
BlockingQueue<ResultBatch<T>> intermediateOutput = new ArrayBlockingQueue<>(queueSize);
nextLevelQueues.add(intermediateOutput);
QueuePusher<T> pusher = new QueuePusher<>(intermediateOutput, cancellationGizmo, hasTimeout, timeoutAt);

List<BatchedResultsCursor<T>> cursors = new ArrayList<>(group.size());
for (BlockingQueue<ResultBatch<T>> queue : group) {
cursors.add(
new BlockingQueueuBatchedResultsCursor<>(queue, cancellationGizmo, orderingFn, hasTimeout, timeoutAt)
);
}
MergeCombineActionMetricsAccumulator intermediateAccumulator = new MergeCombineActionMetricsAccumulator();
getPool().execute(new PrepareMergeCombineInputsAction<>(
cursors,
pusher,
orderingFn,
combineFn,
yieldAfter,
batchSize,
targetTimeNanos,
intermediateAccumulator,
cancellationGizmo
));
intermediateTaskMetrics.add(intermediateAccumulator);
}

for (RecursiveAction task : tasks) {
getPool().execute(task);
LOG.debug(
"Merge tree level %d: %d intermediate merge tasks (fan-in: %d)",
treeLevel,
nextLevelQueues.size(),
DEFAULT_MAX_FAN_IN
);
currentLevelQueues = nextLevelQueues;
treeLevel++;
}

// Step 3: Final merge to output queue
QueuePusher<T> outputPusher = new QueuePusher<>(out, cancellationGizmo, hasTimeout, timeoutAt);
List<BatchedResultsCursor<T>> intermediaryOutputsCursors = new ArrayList<>(intermediaryOutputs.size());
for (BlockingQueue<ResultBatch<T>> queue : intermediaryOutputs) {
intermediaryOutputsCursors.add(
List<BatchedResultsCursor<T>> finalMergeCursors = new ArrayList<>(currentLevelQueues.size());
for (BlockingQueue<ResultBatch<T>> queue : currentLevelQueues) {
finalMergeCursors.add(
new BlockingQueueuBatchedResultsCursor<>(queue, cancellationGizmo, orderingFn, hasTimeout, timeoutAt)
);
}
MergeCombineActionMetricsAccumulator finalMergeMetrics = new MergeCombineActionMetricsAccumulator();

metricsAccumulator.setMergeMetrics(finalMergeMetrics);
PrepareMergeCombineInputsAction<T> finalMergeAction = new PrepareMergeCombineInputsAction<>(
intermediaryOutputsCursors,
getPool().execute(new PrepareMergeCombineInputsAction<>(
finalMergeCursors,
outputPusher,
orderingFn,
combineFn,
Expand All @@ -456,53 +494,61 @@ private void spawnParallelTasks(int parallelMergeTasks)
targetTimeNanos,
finalMergeMetrics,
cancellationGizmo
);
));

metricsAccumulator.setPartitions(leafTaskMetrics);
metricsAccumulator.setIntermediateMetrics(intermediateTaskMetrics);
metricsAccumulator.setMergeMetrics(finalMergeMetrics);

getPool().execute(finalMergeAction);
LOG.debug(
"Multi-level merge tree: %d leaf tasks, %d intermediate tasks, 1 final merge (%d-way), %d total levels",
leafTaskMetrics.size(),
intermediateTaskMetrics.size(),
currentLevelQueues.size(),
treeLevel + 1
);
}

/**
* Computes maximum number of layer 1 parallel merging tasks given available processors and an estimate of current
* {@link ForkJoinPool} utilization. A return value of 1 or less indicates that a serial merge will be done on
* the pool instead.
* Computes maximum number of leaf-level parallel merging tasks given available processors and an estimate of
* current {@link ForkJoinPool} utilization. A return value of 1 or less indicates that a serial merge will be
* done on the pool instead. With multi-level merge tree, reserves capacity for intermediate merge levels.
*/
private int computeNumTasks()
{
final int runningThreadCount = getPool().getRunningThreadCount();
final int submissionCount = getPool().getQueuedSubmissionCount();

// max is smaller of either:
// - parallelism passed into sequence (number of physical cores by default)
// - pool parallelism (number of physical cores * 1.5 by default)
final int maxParallelism = Math.min(parallelism, getPool().getParallelism());

// we consider 'utilization' to be the number of running threads + submitted tasks that have not yet started
// running, minus 1 for the task that is running this calculation (as it will be replaced with the parallel tasks)
final int utilizationEstimate = runningThreadCount + submissionCount - 1;

// 'computed parallelism' is the remainder of the 'max parallelism' less current 'utilization estimate'
final int computedParallelismForUtilization = maxParallelism - utilizationEstimate;

// try to balance partition size with partition count so we don't end up with layer 2 'final merge' task that has
// significantly more work to do than the layer 1 'parallel' tasks.
final int computedParallelismForSequences = (int) Math.floor(Math.sqrt(sequences.size()));
// With multi-level merge tree (bounded fan-in at each level), we don't need the sqrt(N)
// heuristic that was designed for 2-level trees. Reserve ~1/MAX_FAN_IN of available parallelism
// for intermediate merge levels (each level has roughly 1/MAX_FAN_IN tasks of the level below).
final int computedLeafTasks;
if (computedParallelismForUtilization <= 1) {
computedLeafTasks = 1;
} else {
computedLeafTasks = Math.max(
1,
(int) (computedParallelismForUtilization * ((double) (DEFAULT_MAX_FAN_IN - 1) / DEFAULT_MAX_FAN_IN))
);
}

// compute total number of layer 1 'parallel' tasks, for the utilization parallelism, subtract 1 as the final merge
// task will take the remaining slot
final int computedOptimalParallelism = Math.min(
computedParallelismForSequences,
computedParallelismForUtilization - 1
// Each leaf task should merge at least 2 sequences to be worthwhile
final int computedNumParallelTasks = Math.max(
Math.min(computedLeafTasks, sequences.size() / 2),
1
);

final int computedNumParallelTasks = Math.max(computedOptimalParallelism, 1);

if (LOG.isDebugEnabled()) {
ForkJoinPool pool = getPool();
LOG.debug(
"Computed parallel tasks: [%s]; ForkJoinPool details - sequence parallelism: [%s] "
"Computed parallel tasks: [%s] (max fan-in: [%s]); ForkJoinPool details - sequence parallelism: [%s] "
+ "active threads: [%s] running threads: [%s] queued submissions: [%s] queued tasks: [%s] "
+ "pool parallelism: [%s] pool size: [%s] steal count: [%s]",
computedNumParallelTasks, parallelism,
computedNumParallelTasks, DEFAULT_MAX_FAN_IN, parallelism,
pool.getActiveThreadCount(), runningThreadCount, submissionCount, pool.getQueuedTaskCount(),
pool.getParallelism(), pool.getPoolSize(), pool.getStealCount()
);
Expand Down Expand Up @@ -1410,13 +1456,15 @@ public long getSlowestPartitionInitializedTime()
}

/**
* Holder to accumulate metrics for all work done {@link ParallelMergeCombiningSequence}, containing layer 1 task
* metrics in {@link #partitionMetrics} and final merge task metrics in {@link #mergeMetrics}, in order to compute
* {@link MergeCombineMetrics} after the {@link ParallelMergeCombiningSequence} is completely consumed.
* Holder to accumulate metrics for all work done {@link ParallelMergeCombiningSequence}, containing leaf task
* metrics in {@link #partitionMetrics}, intermediate merge level metrics in {@link #intermediateMetrics}, and
* final merge task metrics in {@link #mergeMetrics}, in order to compute {@link MergeCombineMetrics} after the
* {@link ParallelMergeCombiningSequence} is completely consumed.
*/
static class MergeCombineMetricsAccumulator
{
List<MergeCombineActionMetricsAccumulator> partitionMetrics = Collections.emptyList();
List<MergeCombineActionMetricsAccumulator> intermediateMetrics = Collections.emptyList();
MergeCombineActionMetricsAccumulator mergeMetrics = new MergeCombineActionMetricsAccumulator();

private long totalWallTime;
Expand All @@ -1438,6 +1486,11 @@ void setPartitions(List<MergeCombineActionMetricsAccumulator> partitionMetrics)
this.partitionMetrics = partitionMetrics;
}

/**
 * Sets the metrics accumulators for the intermediate merge levels of the multi-level merge tree;
 * their cpu time and task counts are folded into the totals when the final metrics are built.
 */
void setIntermediateMetrics(List<MergeCombineActionMetricsAccumulator> intermediateMetrics)
{
this.intermediateMetrics = intermediateMetrics;
}

void setTotalWallTime(long time)
{
this.totalWallTime = time;
Expand All @@ -1447,14 +1500,14 @@ MergeCombineMetrics build()
{
long numInputRows = 0;
long cpuTimeNanos = 0;
// 1 partition task, 1 layer two prepare merge inputs task, 1 layer one prepare merge inputs task for each
// partition
long totalPoolTasks = 1 + 1 + partitionMetrics.size();
// 1 partition action, 1 final prepare-merge-inputs action, 1 prepare-merge-inputs per leaf partition,
// 1 prepare-merge-inputs per intermediate merge task
long totalPoolTasks = 1 + 1 + partitionMetrics.size() + intermediateMetrics.size();

long fastestPartInitialized = !partitionMetrics.isEmpty() ? Long.MAX_VALUE : mergeMetrics.getPartitionInitializedtime();
long slowestPartInitialied = !partitionMetrics.isEmpty() ? Long.MIN_VALUE : mergeMetrics.getPartitionInitializedtime();

// accumulate input row count, cpu time, and total number of tasks from each partition
// accumulate input row count, cpu time, and total number of tasks from each leaf partition
for (MergeCombineActionMetricsAccumulator partition : partitionMetrics) {
numInputRows += partition.getInputRows();
cpuTimeNanos += partition.getTotalCpuTimeNanos();
Expand All @@ -1466,8 +1519,15 @@ MergeCombineMetrics build()
slowestPartInitialied = partition.getPartitionInitializedtime();
}
}
// if serial merge done, only mergeMetrics is populated, get input rows from there instead. otherwise, ignore the
// value as it is only the number of intermediary input rows to the layer 2 task

// accumulate cpu time and task count from intermediate merge levels
for (MergeCombineActionMetricsAccumulator intermediate : intermediateMetrics) {
cpuTimeNanos += intermediate.getTotalCpuTimeNanos();
totalPoolTasks += intermediate.getTaskCount();
}

// if serial merge done, only mergeMetrics is populated, get input rows from there instead.
// otherwise, ignore the value as it is only the number of intermediary input rows to the final task
if (partitionMetrics.isEmpty()) {
numInputRows = mergeMetrics.getInputRows();
}
Expand Down
Loading
Loading