Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions third_party/xla/xla/service/cpu/dot_op_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy(
<< "Dot operations must be non-batch";

if (HasDynamicMatmulDims(dot_info)) {
return DotImplementationStrategy::kEigen;
// The runtime path is the safest choice when available. When runtime calls
// are disabled, fall back to the naive LLVM IR path, which can consume
// dynamic loop bounds via dimension expressions.
return allow_runtime_calls ? DotImplementationStrategy::kEigen
: DotImplementationStrategy::kNaiveLlvmIr;
}

// Any Matrix-Vector product of floating point or integral type, or
Expand Down Expand Up @@ -617,9 +621,19 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() {
int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

// Verify the reduction dimension in the two operands are the same size.
CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
rhs_shape.dimensions(rhs_reduction_dimension));
DynExpr* lhs_reduction_expr = lhs_shape.expressions(lhs_reduction_dimension);
DynExpr* rhs_reduction_expr = rhs_shape.expressions(rhs_reduction_dimension);
DynExpr* reduction_expr =
lhs_reduction_expr != nullptr ? lhs_reduction_expr : rhs_reduction_expr;

// Verify that the reduction dimensions of the two operands have the same
// size when both are statically known. Dynamic dimensions are assumed to
// agree at runtime, as guaranteed by dot shape semantics.
if ((lhs_reduction_expr == nullptr || lhs_reduction_expr->is_constant()) &&
(rhs_reduction_expr == nullptr || rhs_reduction_expr->is_constant())) {
CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
rhs_shape.dimensions(rhs_reduction_dimension));
}

bool lhs_reduction_along_minor_dimension =
lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
Expand Down Expand Up @@ -651,7 +665,8 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() {
(lhs_reduction_along_minor_dimension &&
rhs_reduction_along_minor_dimension)
? xla::llvm_ir::UnrollMode::kNoUnroll
: xla::llvm_ir::UnrollMode::kDefaultUnroll);
: xla::llvm_ir::UnrollMode::kDefaultUnroll,
/*prevent_vectorization=*/false, reduction_expr);

// The final entry in the rhs and lhs indexes is the indvar of the
// reduction loop.
Expand Down