From 3d5eae5c184e147b955b8a8ab869cb576bef5d57 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 30 Mar 2026 15:56:49 +0100 Subject: [PATCH 1/2] Do not output-fuse dynamic dots on CPU --- .../xla/service/cpu/cpu_instruction_fusion.cc | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index bd26c1c87621cd..507a1fa9ae13a4 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -59,6 +59,23 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } +bool HasDynamicDimensions(const Shape& shape) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { + if (shape.is_dynamic_dimension(i) || + (shape.expressions(i) && shape.expressions(i)->is_dynamic())) { + return true; + } + } + return false; +} + +bool IsDynamicDot(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kDot && + (HasDynamicDimensions(hlo->shape()) || + HasDynamicDimensions(hlo->operand(0)->shape()) || + HasDynamicDimensions(hlo->operand(1)->shape())); +} + bool HasExactlyOneUse(const HloInstruction& hlo_instr) { return hlo_instr.user_count() == 1 && absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; @@ -142,6 +159,10 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return FusionDecision::Forbid("Don't fuse large constants."); } + if (IsDynamicDot(producer) || IsDynamicDot(consumer)) { + return FusionDecision::Forbid("Do not fuse dynamic dots on CPU."); + } + if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow(); From c0d1b062894df2dc73830f929cf95e98b60f24a0 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 30 Mar 2026 17:37:54 +0100 Subject: [PATCH 2/2] Update third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 507a1fa9ae13a4..54c34104f4b7b0 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -62,7 +62,7 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { bool HasDynamicDimensions(const Shape& shape) { for (int64_t i = 0; i < shape.dimensions().size(); ++i) { if (shape.is_dynamic_dimension(i) || - (shape.expressions(i) && shape.expressions(i)->is_dynamic())) { + shape.expressions(i)->is_dynamic()) { return true; } }