diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e201b87a4a2a6a..98100c2e76cc28 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -195,7 +195,11 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( << "Dot operations must be non-batch"; if (HasDynamicMatmulDims(dot_info)) { - return DotImplementationStrategy::kEigen; + // The runtime path is the safest choice when available. When runtime calls + // are disabled, fall back to the naive LLVM IR path, which can consume + // dynamic loop bounds via dimension expressions. + return allow_runtime_calls ? DotImplementationStrategy::kEigen + : DotImplementationStrategy::kNaiveLlvmIr; } // Any Matrix-Vector product of floating point or integral type, or @@ -617,9 +621,19 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); - // Verify the reduction dimension in the two operands are the same size. - CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), - rhs_shape.dimensions(rhs_reduction_dimension)); + DynExpr* lhs_reduction_expr = lhs_shape.expressions(lhs_reduction_dimension); + DynExpr* rhs_reduction_expr = rhs_shape.expressions(rhs_reduction_dimension); + DynExpr* reduction_expr = + lhs_reduction_expr != nullptr ? lhs_reduction_expr : rhs_reduction_expr; + + // Verify the reduction dimension in the two operands are the same size when + // it is statically known. Dynamic dimensions are assumed to agree at + // runtime, as guaranteed by dot shape semantics. + if ((lhs_reduction_expr == nullptr || lhs_reduction_expr->is_constant()) && + (rhs_reduction_expr == nullptr || rhs_reduction_expr->is_constant())) { + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); + } bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -651,7 +665,8 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { (lhs_reduction_along_minor_dimension && rhs_reduction_along_minor_dimension) ? xla::llvm_ir::UnrollMode::kNoUnroll - : xla::llvm_ir::UnrollMode::kDefaultUnroll); + : xla::llvm_ir::UnrollMode::kDefaultUnroll, + /*prevent_vectorization=*/false, reduction_expr); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop.