diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index bd26c1c87621cd..54c34104f4b7b0 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -59,6 +59,23 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } +bool HasDynamicDimensions(const Shape& shape) { + for (int64_t i = 0; i < shape.dimensions().size(); ++i) { + if (shape.is_dynamic_dimension(i) || + shape.expressions(i)->is_dynamic()) { + return true; + } + } + return false; +} + +bool IsDynamicDot(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kDot && + (HasDynamicDimensions(hlo->shape()) || + HasDynamicDimensions(hlo->operand(0)->shape()) || + HasDynamicDimensions(hlo->operand(1)->shape())); +} + bool HasExactlyOneUse(const HloInstruction& hlo_instr) { return hlo_instr.user_count() == 1 && absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; @@ -142,6 +159,10 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return FusionDecision::Forbid("Don't fuse large constants."); } + if (IsDynamicDot(producer) || IsDynamicDot(consumer)) { + return FusionDecision::Forbid("Do not fuse dynamic dots on CPU."); + } + if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow();