From 9459b4a650fca2d4b3f741422473fe98d8a1d7ba Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Thu, 2 Apr 2026 22:06:46 +0100 Subject: [PATCH 1/4] Pad temp buffer storage to power-of-two sizes --- .../xla/xla/service/buffer_assignment.cc | 85 ++++++++++++++----- .../xla/xla/service/buffer_assignment.h | 32 +++++-- 2 files changed, 90 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 487c8ce7cbb96c..d91f8ff1a07fe9 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -79,6 +80,29 @@ using absl::StrAppendFormat; using memory_space_assignment::PresetAssignments; using ::tsl::strings::HumanReadableNumBytes; +int64_t NextPowerOfTwo(int64_t value) { + CHECK_GT(value, 0); + if ((value & (value - 1)) == 0) { + return value; + } + uint64_t rounded = 1; + while (rounded < static_cast(value)) { + rounded <<= 1; + } + CHECK_LE(rounded, static_cast(std::numeric_limits::max())); + return static_cast(rounded); +} + +bool ShouldPadBufferStorage(const HloValue& value, + const HloAliasAnalysis& alias_analysis) { + const HloInstruction* instruction = value.instruction(); + const bool is_entry_parameter = + instruction->opcode() == HloOpcode::kParameter && + instruction->parent() == instruction->GetModule()->entry_computation(); + return instruction->opcode() != HloOpcode::kConstant && !is_entry_parameter && + !value.shape().IsTuple() && !alias_analysis.ValueLivesOut(value); +} + absl::flat_hash_map BuildIdToHloInstructionMap( const HloModule* module) { // Build a map from a unique_id to corresponding HloInstruction in the module. @@ -607,9 +631,12 @@ BufferAllocation* BufferAssignment::NewEmptyAllocation( } BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer, - int64_t size) { + int64_t size, + std::optional + assigned_size) { BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color()); - AddAssignment(allocation, buffer, /*offset=*/0, size); + AddAssignment(allocation, buffer, /*offset=*/0, + assigned_size.value_or(size)); allocation->peak_buffers_.push_back(buffer.values()[0]); return allocation; } @@ -815,7 +842,7 @@ absl::StatusOr BufferAssignment::ComputeTotalFragmentationBytes() TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64_t min_size, - HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, storage_buffer_size_)); return stats_.total_allocation_bytes - min_size; } return -1; @@ -1167,12 +1194,12 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto proto; // NOTE: DataflowAnalysis state is serialized here in BufferAssignment, // because we need to do the HasAllocation check for each buffer. Otherwise - // the buffer_size_ call might fail for some backends. + // the logical_buffer_size_ call might fail for some backends. const HloDataflowAnalysis& dataflow = this->dataflow_analysis(); for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) { auto& value = dataflow.values().at(id); if (HasAllocation(*value)) { - LogicalBufferProto proto_buffer = value->ToProto(buffer_size_); + LogicalBufferProto proto_buffer = value->ToProto(logical_buffer_size_); proto.add_logical_buffers()->Swap(&proto_buffer); // Fill buffer aliases. @@ -1378,7 +1405,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, << "buffer " << hlo_buffer << " already has an allocation assigned."; VLOG(4) << "Trying to assign " << hlo_buffer << " size " - << assignment->HloBufferSize(hlo_buffer) + << assignment->HloBufferStorageSize(hlo_buffer) << " to allocation: " << *allocation; if (hlo_buffer.color() != allocation->color()) { @@ -1387,9 +1414,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } - if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) { + if (assignment->HloBufferStorageSize(hlo_buffer) > allocation->size()) { VLOG(4) << "Can't assign: buffer is larger than allocation (" - << assignment->HloBufferSize(hlo_buffer) << " > " + << assignment->HloBufferStorageSize(hlo_buffer) << " > " << allocation->size() << ")"; return false; } @@ -1495,12 +1522,14 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( buffers_to_assign_sequentially, std::vector* allocation_indices, BufferAssignment* assignment) { - const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer); + const int64_t logical_buffer_size = assignment->HloBufferSize(*hlo_buffer); + const int64_t storage_buffer_size = + assignment->HloBufferStorageSize(*hlo_buffer); for (const HloValue* value : hlo_buffer->values()) { if (value->instruction()->opcode() == HloOpcode::kConstant) { if (allocate_buffers_for_constants_) { BufferAllocation* allocation = - assignment->NewAllocation(*hlo_buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, logical_buffer_size); allocation->set_constant(true); VLOG(3) << "New allocation #" << allocation->index() << " for constant " << *hlo_buffer << " value ptr: " << value; @@ -1523,7 +1552,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( // computations do not need special allocations because they live inside // callers. BufferAllocation* allocation = - assignment->NewAllocation(*hlo_buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, logical_buffer_size); allocation->set_entry_computation_parameter( instruction->parameter_number(), value->index(), parameter_has_alias); @@ -1538,7 +1567,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( if (is_thread_local) { BufferAllocation* allocation = - assignment->NewAllocation(*hlo_buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, logical_buffer_size); allocation->set_is_thread_local(true); VLOG(3) << "New allocation #" << allocation->index() << " for thread-local: " << *hlo_buffer; @@ -1548,7 +1577,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( for (const HloValue* value : hlo_buffer->values()) { if (value->shape().IsTuple()) { BufferAllocation* allocation = - assignment->NewAllocation(*hlo_buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, logical_buffer_size); allocation->set_is_tuple(true); VLOG(3) << "New allocation #" << allocation->index() << " for tuple-shaped buffer: " << *hlo_buffer; @@ -1616,7 +1645,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( if (!assignment->HasAllocation(*hlo_buffer)) { BufferAllocation* allocation = - assignment->NewAllocation(*hlo_buffer, buffer_size); + assignment->NewAllocation(*hlo_buffer, storage_buffer_size, + logical_buffer_size); allocation_indices->push_back(allocation->index()); VLOG(3) << "New allocation #" << allocation->index() << " for: " << *hlo_buffer; @@ -1702,8 +1732,8 @@ absl::Status BufferAssigner::AssignBuffersForComputations( sorted_buffers, [&post_order_position, &alias_analysis, assignment]( const HloBuffer* a, const HloBuffer* b) { // Primary sort is by decreasing buffer size. - const int64_t a_size = assignment->HloBufferSize(*a); - const int64_t b_size = assignment->HloBufferSize(*b); + const int64_t a_size = assignment->HloBufferStorageSize(*a); + const int64_t b_size = assignment->HloBufferStorageSize(*b); if (a_size != b_size) { return a_size > b_size; // use ">" for decreasing size. } @@ -1922,7 +1952,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( get_heap_algorithm(alignment), *private_stack_computation, *instruction_sequence, assignment->alias_analysis(), - assignment->buffer_size_, &schedule, options)); + assignment->storage_buffer_size_, &schedule, options)); AssignBuffersFromHeapSimulator(result, assignment, color, isolation_options); } @@ -1933,7 +1963,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run(get_heap_algorithm(alignment), assignment->module(), schedule, assignment->alias_analysis(), - assignment->buffer_size_, options)); + assignment->storage_buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, color, isolation_options); } @@ -1967,7 +1997,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run(get_heap_algorithm(alignment), *computation, *instruction_sequence, assignment->alias_analysis(), - assignment->buffer_size_, options)); + assignment->storage_buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, color, isolation_options); } @@ -2185,7 +2215,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( assignment->NewEmptyAllocation(heap_result.heap_size, color); for (const auto& [value, chunk] : heap_result.chunk_map) { - assignment->AddAssignment(allocation, *value, chunk.offset, chunk.size); + assignment->AddAssignment(allocation, *value, chunk.offset, + assignment->logical_buffer_size_(*value)); } allocation->peak_buffers_ = ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); @@ -2232,10 +2263,22 @@ BufferAssigner::CreateAssignment( VLOG(1) << "Number of buffers to assign: " << alias_analysis->buffers().size(); + BufferValue::SizeFunction logical_buffer_size = buffer_size; + BufferValue::SizeFunction storage_buffer_size = + [&alias_analysis = *alias_analysis, + logical_buffer_size](const HloValue& value) -> int64_t { + int64_t logical_size = logical_buffer_size(value); + if (logical_size <= 1 || !ShouldPadBufferStorage(value, alias_analysis)) { + return logical_size; + } + return NextPowerOfTwo(logical_size); + }; + // Can't use std::make_unique because BufferAssignment constructor is // private. std::unique_ptr assignment(new BufferAssignment( - module, std::move(hlo_ordering), std::move(buffer_size), + module, std::move(hlo_ordering), std::move(logical_buffer_size), + std::move(storage_buffer_size), std::move(color_alignment), std::move(alias_analysis), std::move(hlo_live_range))); diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index 762ce70e357f9f..725460536dfffb 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -554,13 +554,15 @@ class BufferAssignment { BufferAssignment(const HloModule* module, std::unique_ptr hlo_ordering, - BufferValue::SizeFunction buffer_size, + BufferValue::SizeFunction logical_buffer_size, + BufferValue::SizeFunction storage_buffer_size, LogicalBuffer::AlignmentFunction color_alignment, std::unique_ptr alias_analysis, std::unique_ptr hlo_live_range) : module_(module), hlo_ordering_(std::move(hlo_ordering)), - buffer_size_(std::move(buffer_size)), + logical_buffer_size_(std::move(logical_buffer_size)), + storage_buffer_size_(std::move(storage_buffer_size)), color_alignment_(std::move(color_alignment)), alias_analysis_(std::move(alias_analysis)), hlo_live_range_(std::move(hlo_live_range)) { @@ -579,7 +581,9 @@ class BufferAssignment { // Helper that calls NewEmptyAllocation and AddAssignment in one call, // creating an allocation containing a single LogicalBuffer. - BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size); + BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size, + std::optional assigned_size = + std::nullopt); // Adds a LogicalBuffer to the set assigned to the given allocation. void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer, @@ -600,12 +604,23 @@ class BufferAssignment { if (iter != cached_buffer_sizes_.end()) return iter->second; int64_t result = 0; for (const HloValue* value : buffer.values()) { - result = std::max(result, buffer_size_(*value)); + result = std::max(result, logical_buffer_size_(*value)); } cached_buffer_sizes_.insert({buffer.id(), result}); return result; } + int64_t HloBufferStorageSize(const HloBuffer& buffer) { + auto iter = cached_storage_buffer_sizes_.find(buffer.id()); + if (iter != cached_storage_buffer_sizes_.end()) return iter->second; + int64_t result = 0; + for (const HloValue* value : buffer.values()) { + result = std::max(result, storage_buffer_size_(*value)); + } + cached_storage_buffer_sizes_.insert({buffer.id(), result}); + return result; + } + // Combines allocations of temporary buffers into one big BufferAllocation. void CombineTempAllocations( const absl::flat_hash_set& private_stack_colors, @@ -633,8 +648,12 @@ class BufferAssignment { const std::unique_ptr hlo_ordering_; - // Function which returns the buffer size for a given logical buffer (shape). - BufferValue::SizeFunction buffer_size_; + // Function which returns the logical buffer size for a given buffer value. + BufferValue::SizeFunction logical_buffer_size_; + + // Function which returns the storage size to reserve for a given buffer + // value. This may be larger than the logical size when we pad allocations. + BufferValue::SizeFunction storage_buffer_size_; // Function which returns the alignment for a given logical buffer color. LogicalBuffer::AlignmentFunction color_alignment_; @@ -646,6 +665,7 @@ class BufferAssignment { Stats stats_; absl::flat_hash_map cached_buffer_sizes_; + absl::flat_hash_map cached_storage_buffer_sizes_; BufferAssignment(const BufferAssignment&) = delete; BufferAssignment& operator=(const BufferAssignment&) = delete; From 6776abaa2a73d0a231c3ec873a2a97b0cb6dd782 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Thu, 2 Apr 2026 22:17:46 +0100 Subject: [PATCH 2/4] Allow padded temp buffer copies to logical outputs --- third_party/xla/xla/service/cpu/cpu_executable.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index de9ffc2b78eb80..38471c3e3e1cc1 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -450,7 +450,7 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( stream->parent()->device_ordinal(), allocation_size)); result_buffer = allocated_buffer.Release(); MaybeOwningDeviceMemory& registered_buffer = buffers[buffer_index]; - CHECK_EQ(result_buffer.size(), + CHECK_LE(result_buffer.size(), registered_buffer.AsDeviceMemoryBase().size()); std::memcpy(/*dest=*/result_buffer.opaque(), /*src=*/registered_buffer.AsDeviceMemoryBase().opaque(), From 24c9904c4ecd854bdb7c70f8dc77c4a141521aed Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Thu, 2 Apr 2026 22:28:05 +0100 Subject: [PATCH 3/4] Fix buffer assignment constructor for padded storage --- third_party/xla/xla/service/buffer_assignment.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index d91f8ff1a07fe9..09b019488fa3d5 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -1247,9 +1247,12 @@ absl::StatusOr> BufferAssignment::FromProto( id_to_logical_buffer, BuildIdToLogicalBufferMap(proto, id_to_hlo_instruction, alias_analysis)); + BufferValue::SizeFunction logical_buffer_size = buffer_size; + BufferValue::SizeFunction storage_buffer_size = buffer_size; std::unique_ptr buffer_assignment = absl::WrapUnique(new BufferAssignment( - module, /*hlo_ordering=*/nullptr, std::move(buffer_size), + module, /*hlo_ordering=*/nullptr, std::move(logical_buffer_size), + std::move(storage_buffer_size), /*color_alignment=*/nullptr, std::move(alias_analysis), /*hlo_live_range=*/nullptr)); From c945ab8e959ca2bfdb7f030af6dc74edce436c0f Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Thu, 2 Apr 2026 22:40:09 +0100 Subject: [PATCH 4/4] Pad storage only for dynamic-expression buffers --- third_party/xla/xla/service/buffer_assignment.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 09b019488fa3d5..208b08f202fcac 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -99,7 +99,8 @@ bool ShouldPadBufferStorage(const HloValue& value, const bool is_entry_parameter = instruction->opcode() == HloOpcode::kParameter && instruction->parent() == instruction->GetModule()->entry_computation(); - return instruction->opcode() != HloOpcode::kConstant && !is_entry_parameter && + return value.shape().has_dynamic_expr() && + instruction->opcode() != HloOpcode::kConstant && !is_entry_parameter && !value.shape().IsTuple() && !alias_analysis.ValueLivesOut(value); }