From 3d5eae5c184e147b955b8a8ab869cb576bef5d57 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 30 Mar 2026 15:56:49 +0100 Subject: [PATCH 1/3] Do not output-fuse dynamic dots on CPU --- .../xla/service/cpu/cpu_instruction_fusion.cc | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index bd26c1c87621cd..507a1fa9ae13a4 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -59,6 +59,23 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } +bool HasDynamicDimensions(const Shape& shape) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { + if (shape.is_dynamic_dimension(i) || + (shape.expressions(i) && shape.expressions(i)->is_dynamic())) { + return true; + } + } + return false; +} + +bool IsDynamicDot(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kDot && + (HasDynamicDimensions(hlo->shape()) || + HasDynamicDimensions(hlo->operand(0)->shape()) || + HasDynamicDimensions(hlo->operand(1)->shape())); +} + bool HasExactlyOneUse(const HloInstruction& hlo_instr) { return hlo_instr.user_count() == 1 && absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; @@ -142,6 +159,10 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return FusionDecision::Forbid("Don't fuse large constants."); } + if (IsDynamicDot(producer) || IsDynamicDot(consumer)) { + return FusionDecision::Forbid("Do not fuse dynamic dots on CPU."); + } + if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow(); From 2c1fb9227ee16c25c1686b3b1ad2beed2167d4e2 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 30 Mar 2026 17:33:00 +0100 Subject: [PATCH 2/3] Try naive LLVM IR for dynamic dots --- .../xla/xla/service/cpu/dot_op_emitter.cc | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e201b87a4a2a6a..98100c2e76cc28 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -195,7 +195,11 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( << "Dot operations must be non-batch"; if (HasDynamicMatmulDims(dot_info)) { - return DotImplementationStrategy::kEigen; + // The runtime path is the safest choice when available. When runtime calls + // are disabled, fall back to the naive LLVM IR path, which can consume + // dynamic loop bounds via dimension expressions. + return allow_runtime_calls ? DotImplementationStrategy::kEigen + : DotImplementationStrategy::kNaiveLlvmIr; } // Any Matrix-Vector product of floating point or integral type, or @@ -617,9 +621,19 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); - // Verify the reduction dimension in the two operands are the same size. - CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), - rhs_shape.dimensions(rhs_reduction_dimension)); + DynExpr* lhs_reduction_expr = lhs_shape.expressions(lhs_reduction_dimension); + DynExpr* rhs_reduction_expr = rhs_shape.expressions(rhs_reduction_dimension); + DynExpr* reduction_expr = + lhs_reduction_expr != nullptr ? lhs_reduction_expr : rhs_reduction_expr; + + // Verify the reduction dimension in the two operands are the same size when + // it is statically known. Dynamic dimensions are assumed to agree at + // runtime, as guaranteed by dot shape semantics. + if ((lhs_reduction_expr == nullptr || lhs_reduction_expr->is_constant()) && + (rhs_reduction_expr == nullptr || rhs_reduction_expr->is_constant())) { + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); + } bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -651,7 +665,8 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { (lhs_reduction_along_minor_dimension && rhs_reduction_along_minor_dimension) ? xla::llvm_ir::UnrollMode::kNoUnroll - : xla::llvm_ir::UnrollMode::kDefaultUnroll); + : xla::llvm_ir::UnrollMode::kDefaultUnroll, + /*prevent_vectorization=*/false, reduction_expr); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. From 13d887a9f5920b2c3ca1c411b64ff0e831fd9f20 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 30 Mar 2026 17:36:28 +0100 Subject: [PATCH 3/3] Drop dynamic dot fusion guard --- .../xla/service/cpu/cpu_instruction_fusion.cc | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 507a1fa9ae13a4..bd26c1c87621cd 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -59,23 +59,6 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } -bool HasDynamicDimensions(const Shape& shape) { - for (int64_t i = 0; i < shape.dimensions().size(); ++i) { - if (shape.is_dynamic_dimension(i) || - (shape.expressions(i) && shape.expressions(i)->is_dynamic())) { - return true; - } - } - return false; -} - -bool IsDynamicDot(const HloInstruction* hlo) { - return hlo->opcode() == HloOpcode::kDot && - (HasDynamicDimensions(hlo->shape()) || - HasDynamicDimensions(hlo->operand(0)->shape()) || - HasDynamicDimensions(hlo->operand(1)->shape())); -} - bool HasExactlyOneUse(const HloInstruction& hlo_instr) { return hlo_instr.user_count() == 1 && absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; @@ -159,10 +142,6 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return FusionDecision::Forbid("Don't fuse large constants."); } - if (IsDynamicDot(producer) || IsDynamicDot(consumer)) { - return FusionDecision::Forbid("Do not fuse dynamic dots on CPU."); - } - if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow();