Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions third_party/xla/xla/service/buffer_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <cstdint>
#include <deque>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <ostream>
Expand Down Expand Up @@ -79,6 +80,30 @@ using absl::StrAppendFormat;
using memory_space_assignment::PresetAssignments;
using ::tsl::strings::HumanReadableNumBytes;

int64_t NextPowerOfTwo(int64_t value) {
CHECK_GT(value, 0);
if ((value & (value - 1)) == 0) {
return value;
}
uint64_t rounded = 1;
while (rounded < static_cast<uint64_t>(value)) {
rounded <<= 1;
}
CHECK_LE(rounded, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
return static_cast<int64_t>(rounded);
}

bool ShouldPadBufferStorage(const HloValue& value,
const HloAliasAnalysis& alias_analysis) {
const HloInstruction* instruction = value.instruction();
const bool is_entry_parameter =
instruction->opcode() == HloOpcode::kParameter &&
instruction->parent() == instruction->GetModule()->entry_computation();
return value.shape().has_dynamic_expr() &&
instruction->opcode() != HloOpcode::kConstant && !is_entry_parameter &&
!value.shape().IsTuple() && !alias_analysis.ValueLivesOut(value);
}

absl::flat_hash_map<int64_t, const HloInstruction*> BuildIdToHloInstructionMap(
const HloModule* module) {
// Build a map from a unique_id to corresponding HloInstruction in the module.
Expand Down Expand Up @@ -607,9 +632,12 @@ BufferAllocation* BufferAssignment::NewEmptyAllocation(
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
int64_t size) {
int64_t size,
std::optional<int64_t>
assigned_size) {
BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
AddAssignment(allocation, buffer, /*offset=*/0, size);
AddAssignment(allocation, buffer, /*offset=*/0,
assigned_size.value_or(size));
allocation->peak_buffers_.push_back(buffer.values()[0]);
return allocation;
}
Expand Down Expand Up @@ -815,7 +843,7 @@ absl::StatusOr<int64_t> BufferAssignment::ComputeTotalFragmentationBytes()
TF_RETURN_IF_ERROR(schedule.Verify());
TF_ASSIGN_OR_RETURN(
const int64_t min_size,
HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
HeapSimulator::MinimumMemoryForModule(schedule, storage_buffer_size_));
return stats_.total_allocation_bytes - min_size;
}
return -1;
Expand Down Expand Up @@ -1167,12 +1195,12 @@ BufferAssignmentProto BufferAssignment::ToProto() const {
BufferAssignmentProto proto;
// NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
// because we need to do the HasAllocation check for each buffer. Otherwise
// the buffer_size_ call might fail for some backends.
// the logical_buffer_size_ call might fail for some backends.
const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
auto& value = dataflow.values().at(id);
if (HasAllocation(*value)) {
LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
LogicalBufferProto proto_buffer = value->ToProto(logical_buffer_size_);
proto.add_logical_buffers()->Swap(&proto_buffer);

// Fill buffer aliases.
Expand Down Expand Up @@ -1220,9 +1248,12 @@ absl::StatusOr<std::unique_ptr<BufferAssignment>> BufferAssignment::FromProto(
id_to_logical_buffer,
BuildIdToLogicalBufferMap(proto, id_to_hlo_instruction, alias_analysis));

BufferValue::SizeFunction logical_buffer_size = buffer_size;
BufferValue::SizeFunction storage_buffer_size = buffer_size;
std::unique_ptr<BufferAssignment> buffer_assignment =
absl::WrapUnique(new BufferAssignment(
module, /*hlo_ordering=*/nullptr, std::move(buffer_size),
module, /*hlo_ordering=*/nullptr, std::move(logical_buffer_size),
std::move(storage_buffer_size),
/*color_alignment=*/nullptr, std::move(alias_analysis),
/*hlo_live_range=*/nullptr));

Expand Down Expand Up @@ -1378,7 +1409,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
<< "buffer " << hlo_buffer << " already has an allocation assigned.";

VLOG(4) << "Trying to assign " << hlo_buffer << " size "
<< assignment->HloBufferSize(hlo_buffer)
<< assignment->HloBufferStorageSize(hlo_buffer)
<< " to allocation: " << *allocation;

if (hlo_buffer.color() != allocation->color()) {
Expand All @@ -1387,9 +1418,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
return false;
}

if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
if (assignment->HloBufferStorageSize(hlo_buffer) > allocation->size()) {
VLOG(4) << "Can't assign: buffer is larger than allocation ("
<< assignment->HloBufferSize(hlo_buffer) << " > "
<< assignment->HloBufferStorageSize(hlo_buffer) << " > "
<< allocation->size() << ")";
return false;
}
Expand Down Expand Up @@ -1495,12 +1526,14 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(
buffers_to_assign_sequentially,
std::vector<BufferAllocation::Index>* allocation_indices,
BufferAssignment* assignment) {
const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer);
const int64_t logical_buffer_size = assignment->HloBufferSize(*hlo_buffer);
const int64_t storage_buffer_size =
assignment->HloBufferStorageSize(*hlo_buffer);
for (const HloValue* value : hlo_buffer->values()) {
if (value->instruction()->opcode() == HloOpcode::kConstant) {
if (allocate_buffers_for_constants_) {
BufferAllocation* allocation =
assignment->NewAllocation(*hlo_buffer, buffer_size);
assignment->NewAllocation(*hlo_buffer, logical_buffer_size);
allocation->set_constant(true);
VLOG(3) << "New allocation #" << allocation->index() << " for constant "
<< *hlo_buffer << " value ptr: " << value;
Expand All @@ -1523,7 +1556,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(
// computations do not need special allocations because they live inside
// callers.
BufferAllocation* allocation =
assignment->NewAllocation(*hlo_buffer, buffer_size);
assignment->NewAllocation(*hlo_buffer, logical_buffer_size);

allocation->set_entry_computation_parameter(
instruction->parameter_number(), value->index(), parameter_has_alias);
Expand All @@ -1538,7 +1571,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(

if (is_thread_local) {
BufferAllocation* allocation =
assignment->NewAllocation(*hlo_buffer, buffer_size);
assignment->NewAllocation(*hlo_buffer, logical_buffer_size);
allocation->set_is_thread_local(true);
VLOG(3) << "New allocation #" << allocation->index()
<< " for thread-local: " << *hlo_buffer;
Expand All @@ -1548,7 +1581,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(
for (const HloValue* value : hlo_buffer->values()) {
if (value->shape().IsTuple()) {
BufferAllocation* allocation =
assignment->NewAllocation(*hlo_buffer, buffer_size);
assignment->NewAllocation(*hlo_buffer, logical_buffer_size);
allocation->set_is_tuple(true);
VLOG(3) << "New allocation #" << allocation->index()
<< " for tuple-shaped buffer: " << *hlo_buffer;
Expand Down Expand Up @@ -1616,7 +1649,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(

if (!assignment->HasAllocation(*hlo_buffer)) {
BufferAllocation* allocation =
assignment->NewAllocation(*hlo_buffer, buffer_size);
assignment->NewAllocation(*hlo_buffer, storage_buffer_size,
logical_buffer_size);
allocation_indices->push_back(allocation->index());
VLOG(3) << "New allocation #" << allocation->index()
<< " for: " << *hlo_buffer;
Expand Down Expand Up @@ -1702,8 +1736,8 @@ absl::Status BufferAssigner::AssignBuffersForComputations(
sorted_buffers, [&post_order_position, &alias_analysis, assignment](
const HloBuffer* a, const HloBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64_t a_size = assignment->HloBufferSize(*a);
const int64_t b_size = assignment->HloBufferSize(*b);
const int64_t a_size = assignment->HloBufferStorageSize(*a);
const int64_t b_size = assignment->HloBufferStorageSize(*b);
if (a_size != b_size) {
return a_size > b_size; // use ">" for decreasing size.
}
Expand Down Expand Up @@ -1922,7 +1956,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering(
HeapSimulator::Run(
get_heap_algorithm(alignment), *private_stack_computation,
*instruction_sequence, assignment->alias_analysis(),
assignment->buffer_size_, &schedule, options));
assignment->storage_buffer_size_, &schedule, options));
AssignBuffersFromHeapSimulator(result, assignment, color,
isolation_options);
}
Expand All @@ -1933,7 +1967,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering(
HeapSimulator::Run(get_heap_algorithm(alignment),
assignment->module(), schedule,
assignment->alias_analysis(),
assignment->buffer_size_, options));
assignment->storage_buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment, color,
isolation_options);
}
Expand Down Expand Up @@ -1967,7 +2001,7 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering(
HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
*instruction_sequence,
assignment->alias_analysis(),
assignment->buffer_size_, options));
assignment->storage_buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment, color,
isolation_options);
}
Expand Down Expand Up @@ -2185,7 +2219,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator(
assignment->NewEmptyAllocation(heap_result.heap_size, color);

for (const auto& [value, chunk] : heap_result.chunk_map) {
assignment->AddAssignment(allocation, *value, chunk.offset, chunk.size);
assignment->AddAssignment(allocation, *value, chunk.offset,
assignment->logical_buffer_size_(*value));
}
allocation->peak_buffers_ =
ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
Expand Down Expand Up @@ -2232,10 +2267,22 @@ BufferAssigner::CreateAssignment(
VLOG(1) << "Number of buffers to assign: "
<< alias_analysis->buffers().size();

BufferValue::SizeFunction logical_buffer_size = buffer_size;
BufferValue::SizeFunction storage_buffer_size =
[&alias_analysis = *alias_analysis,
logical_buffer_size](const HloValue& value) -> int64_t {
int64_t logical_size = logical_buffer_size(value);
if (logical_size <= 1 || !ShouldPadBufferStorage(value, alias_analysis)) {
return logical_size;
}
return NextPowerOfTwo(logical_size);
};

// Can't use std::make_unique because BufferAssignment constructor is
// private.
std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
module, std::move(hlo_ordering), std::move(buffer_size),
module, std::move(hlo_ordering), std::move(logical_buffer_size),
std::move(storage_buffer_size),
std::move(color_alignment), std::move(alias_analysis),
std::move(hlo_live_range)));

Expand Down
32 changes: 26 additions & 6 deletions third_party/xla/xla/service/buffer_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,15 @@ class BufferAssignment {

BufferAssignment(const HloModule* module,
std::unique_ptr<HloOrdering> hlo_ordering,
BufferValue::SizeFunction buffer_size,
BufferValue::SizeFunction logical_buffer_size,
BufferValue::SizeFunction storage_buffer_size,
LogicalBuffer::AlignmentFunction color_alignment,
std::unique_ptr<HloAliasAnalysis> alias_analysis,
std::unique_ptr<HloLiveRange> hlo_live_range)
: module_(module),
hlo_ordering_(std::move(hlo_ordering)),
buffer_size_(std::move(buffer_size)),
logical_buffer_size_(std::move(logical_buffer_size)),
storage_buffer_size_(std::move(storage_buffer_size)),
color_alignment_(std::move(color_alignment)),
alias_analysis_(std::move(alias_analysis)),
hlo_live_range_(std::move(hlo_live_range)) {
Expand All @@ -579,7 +581,9 @@ class BufferAssignment {

// Helper that calls NewEmptyAllocation and AddAssignment in one call,
// creating an allocation containing a single LogicalBuffer.
BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size);
BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size,
std::optional<int64_t> assigned_size =
std::nullopt);

// Adds a LogicalBuffer to the set assigned to the given allocation.
void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer,
Expand All @@ -600,12 +604,23 @@ class BufferAssignment {
if (iter != cached_buffer_sizes_.end()) return iter->second;
int64_t result = 0;
for (const HloValue* value : buffer.values()) {
result = std::max(result, buffer_size_(*value));
result = std::max(result, logical_buffer_size_(*value));
}
cached_buffer_sizes_.insert({buffer.id(), result});
return result;
}

int64_t HloBufferStorageSize(const HloBuffer& buffer) {
auto iter = cached_storage_buffer_sizes_.find(buffer.id());
if (iter != cached_storage_buffer_sizes_.end()) return iter->second;
int64_t result = 0;
for (const HloValue* value : buffer.values()) {
result = std::max(result, storage_buffer_size_(*value));
}
cached_storage_buffer_sizes_.insert({buffer.id(), result});
return result;
}

// Combines allocations of temporary buffers into one big BufferAllocation.
void CombineTempAllocations(
const absl::flat_hash_set<BufferValue::Color>& private_stack_colors,
Expand Down Expand Up @@ -633,8 +648,12 @@ class BufferAssignment {

const std::unique_ptr<HloOrdering> hlo_ordering_;

// Function which returns the buffer size for a given logical buffer (shape).
BufferValue::SizeFunction buffer_size_;
// Function which returns the logical buffer size for a given buffer value.
BufferValue::SizeFunction logical_buffer_size_;

// Function which returns the storage size to reserve for a given buffer
// value. This may be larger than the logical size when we pad allocations.
BufferValue::SizeFunction storage_buffer_size_;

// Function which returns the alignment for a given logical buffer color.
LogicalBuffer::AlignmentFunction color_alignment_;
Expand All @@ -646,6 +665,7 @@ class BufferAssignment {
Stats stats_;

absl::flat_hash_map<HloBuffer::Id, int64_t> cached_buffer_sizes_;
absl::flat_hash_map<HloBuffer::Id, int64_t> cached_storage_buffer_sizes_;

BufferAssignment(const BufferAssignment&) = delete;
BufferAssignment& operator=(const BufferAssignment&) = delete;
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/cpu/cpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
stream->parent()->device_ordinal(), allocation_size));
result_buffer = allocated_buffer.Release();
MaybeOwningDeviceMemory& registered_buffer = buffers[buffer_index];
CHECK_EQ(result_buffer.size(),
CHECK_LE(result_buffer.size(),
registered_buffer.AsDeviceMemoryBase().size());
std::memcpy(/*dest=*/result_buffer.opaque(),
/*src=*/registered_buffer.AsDeviceMemoryBase().opaque(),
Expand Down