From 612cb5012c4203f99f0a86caac5cb8fcf5440409 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 1 Apr 2026 17:28:59 +0100 Subject: [PATCH 1/5] Migrate DynExpr ownership to DExpr --- .../jit/encapsulate_subgraphs_pass.cc | 25 -- tensorflow/compiler/jit/kernels/xla_ops.cc | 49 ++- .../compiler/jit/mark_for_compilation_pass.cc | 18 +- tensorflow/compiler/jit/shape_inference.cc | 8 +- tensorflow/compiler/jit/xla_launch_util.cc | 12 +- .../compiler/tf2xla/kernels/bincount_op.cc | 10 +- .../tf2xla/kernels/conv_op_helpers.cc | 2 +- .../tf2xla/kernels/dynamic_partition_op.cc | 48 +-- .../tf2xla/kernels/dynamic_stitch_op.cc | 2 +- .../tf2xla/kernels/fake_quantize_ops.cc | 9 +- .../tf2xla/kernels/matrix_diag_ops.cc | 12 +- .../kernels/matrix_triangular_solve_op.cc | 18 +- tensorflow/compiler/tf2xla/kernels/pack_op.cc | 5 +- .../compiler/tf2xla/kernels/pooling_ops.cc | 4 +- .../kernels/quantize_and_dequantize_op.cc | 4 +- .../tf2xla/kernels/reduction_ops_common.cc | 4 +- .../compiler/tf2xla/kernels/reshape_op.cc | 64 ++-- .../tf2xla/kernels/reverse_sequence_op.cc | 4 +- .../compiler/tf2xla/kernels/shape_op.cc | 12 +- .../compiler/tf2xla/kernels/slice_op.cc | 27 +- .../compiler/tf2xla/kernels/split_op.cc | 36 +- .../tf2xla/kernels/stateless_random_ops.cc | 3 +- .../tf2xla/kernels/strided_slice_op.cc | 53 +-- .../tf2xla/kernels/tensor_array_ops.cc | 14 +- .../tf2xla/kernels/tensor_list_ops.cc | 16 +- .../tf2xla/kernels/tensor_list_utils.cc | 51 ++- .../compiler/tf2xla/kernels/tile_ops.cc | 4 +- .../compiler/tf2xla/kernels/unique_op.cc | 14 +- .../compiler/tf2xla/kernels/where_op.cc | 18 +- tensorflow/compiler/tf2xla/lib/broadcast.cc | 2 +- tensorflow/compiler/tf2xla/lib/broadcast.h | 2 +- tensorflow/compiler/tf2xla/lib/data_format.cc | 5 +- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 13 +- tensorflow/compiler/tf2xla/shape_util.cc | 35 +- tensorflow/compiler/tf2xla/xla_argument.h | 4 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 13 +- 
tensorflow/compiler/tf2xla/xla_expression.cc | 2 +- tensorflow/compiler/tf2xla/xla_expression.h | 10 +- tensorflow/core/framework/tensor_shape.cc | 193 +++++----- tensorflow/core/framework/tensor_shape.h | 44 +-- .../core/grappler/costs/graph_properties.cc | 7 +- tensorflow/core/kernels/padding_fifo_queue.cc | 6 +- tensorflow/core/kernels/strided_slice_op.cc | 4 +- tensorflow/core/util/strided_slice_op.cc | 88 ++--- tensorflow/core/util/strided_slice_op.h | 8 +- .../xla/xla/hlo/builder/lib/broadcast.cc | 27 +- .../xla/xla/hlo/builder/lib/broadcast.h | 2 +- third_party/xla/xla/hlo/builder/lib/matrix.cc | 6 +- third_party/xla/xla/hlo/builder/lib/prng.cc | 8 +- .../xla/xla/hlo/builder/lib/slicing.cc | 4 +- .../xla/xla/hlo/builder/xla_builder.cc | 145 +++---- third_party/xla/xla/hlo/builder/xla_builder.h | 58 +-- .../expanders/bitcast_dtypes_expander.cc | 4 +- .../transforms/expanders/dot_decomposer.cc | 36 +- .../hlo/transforms/expanders/qr_expander.cc | 18 +- .../expanders/rng_bit_generator_expander.cc | 2 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 8 +- .../translate/mhlo_to_hlo/type_to_shape.cc | 2 +- .../xla/service/cpu/cpu_instruction_fusion.cc | 21 -- .../xla/xla/service/cpu/dot_op_emitter.cc | 78 ++-- third_party/xla/xla/service/cpu/ir_emitter.cc | 43 +-- third_party/xla/xla/service/cpu/ir_emitter.h | 6 +- .../xla/xla/service/cpu/ir_emitter2.cc | 3 +- .../xla/service/dynamic_constant_rewriter.cc | 34 +- .../xla/xla/service/elemental_ir_emitter.cc | 16 +- .../xla/xla/service/hlo_creation_utils.cc | 11 +- .../xla/xla/service/llvm_ir/ir_array.cc | 8 +- .../xla/xla/service/llvm_ir/llvm_loop.cc | 11 +- .../xla/xla/service/llvm_ir/llvm_loop.h | 8 +- .../xla/xla/service/llvm_ir/llvm_util.cc | 59 +-- .../xla/xla/service/llvm_ir/llvm_util.h | 4 +- .../xla/xla/service/llvm_ir/loop_emitter.cc | 2 +- .../xla/xla/service/shape_inference.cc | 99 ++--- third_party/xla/xla/service/shape_inference.h | 12 +- .../xla/service/triangular_solve_expander.cc | 4 +- 
third_party/xla/xla/shape.cc | 354 +++++++++++------- third_party/xla/xla/shape.h | 45 ++- third_party/xla/xla/shape_dynexpr.h | 290 ++++++++++++-- third_party/xla/xla/shape_util.cc | 51 ++- third_party/xla/xla/shape_util.h | 13 +- 80 files changed, 1336 insertions(+), 1138 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f1b81fd8e529ec..a3d7b0e71783d3 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -494,31 +494,6 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } -void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { - auto e = expr->s(); - if (xla::Constant* c = dynamic_cast(e)) { - proto->set_constant_value(c->get_val()); - } else if (xla::Variable* v = dynamic_cast(e)) { - proto->set_variable_id(v->get_id()); - } else if (xla::Add* a = dynamic_cast(e)) { - auto* add_msg = proto->mutable_add_node(); - ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); - ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); - } else if (xla::Mul* m = dynamic_cast(e)) { - auto* mul_msg = proto->mutable_mul_node(); - ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); - ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); - } else if (xla::Sub* s = dynamic_cast(e)) { - auto* sub_msg = proto->mutable_sub_node(); - ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); - ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); - } else if (xla::Div* d = dynamic_cast(e)) { - auto* div_msg = proto->mutable_div_node(); - ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); - ExprToProto(d->get_rhs(), div_msg->mutable_rhs()); - } -} - absl::Status Encapsulator::Subgraph::RecordArg( const Edge* edge, const absl::flat_hash_map& node_images, diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc 
b/tensorflow/compiler/jit/kernels/xla_ops.cc index 4dbd650754d77a..782ac342fc8991 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -419,34 +419,33 @@ std::unique_ptr ExprFromProto(const ExpressionProto& proto) { } } -static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { +static xla::DExpr DimExprToDExpr(const DimExpr* e) { switch (e->kind()) { case DimExpr::Kind::kConstant: { auto* ac = static_cast(e); - return xla::DynExpr::_(ac->value()); + return xla::DExpr::Const(ac->value()); } case DimExpr::Kind::kVariable: { - auto* av = static_cast(e); - return xla::DynExpr::V(1); + return xla::DExpr::Var(1); } case DimExpr::Kind::kAdd: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) + DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kSub: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) - DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kMul: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) * DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kDiv: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) / DimExprToDExpr(ee->rhs()); } } - return nullptr; + return xla::DExpr::Unknown(); } @@ -511,12 +510,12 @@ absl::Status CompileToLocalExecutable( int64_t dynamic_dim_value = 0; XlaBatchMatcher* xla_batch_matcher = xla_device_compiler->xla_batch_matcher(); - xla::DynExpr* dynamic_dim_expr = nullptr; - auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DynExpr* expr) { + std::optional dynamic_dim_expr; + auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) { if (!saw_dynamic_dim_value) { saw_dynamic_dim_value = true; dynamic_dim_value = dim_size; - 
dynamic_dim_expr = expr; + dynamic_dim_expr = std::move(expr); return; } if (dynamic_dim_value != dim_size) { @@ -545,18 +544,18 @@ absl::Status CompileToLocalExecutable( std::get(norm_args[arg_index].shape); const AttrValue& v = dyn_dim_attr->second; int64_t idx = v.i(); - record_dynamic_dim_value(shp.dim_size(idx), xla::DynExpr::V(1)); + record_dynamic_dim_value(shp.dim_size(idx), xla::DExpr::Var(1)); if (!filled_batch && xla_batch_matcher) { filled_batch = xla_batch_matcher->get_xla_compile_batch(shp.dim_size(idx)); } - std::vector dyn_exprs; + std::vector dyn_exprs; for (int d : shp.dim_sizes()) { - dyn_exprs.push_back(xla::DynExpr::_(d)); + dyn_exprs.push_back(xla::DExpr::Const(d)); } - dyn_exprs[idx] = xla::DynExpr::V(1); - shp.set_expressions(dyn_exprs); + dyn_exprs[idx] = xla::DExpr::Var(1); + shp.set_expressions(std::move(dyn_exprs)); continue; } auto it = attr_map.find(kXlaInferredOutputShapesAttrName); @@ -570,7 +569,7 @@ absl::Status CompileToLocalExecutable( for (int idx = 0; idx < exp.size(); ++idx) { // Look for dynamic expression. If found then compute padding // value and exit loop. 
- auto e = DimExprToDynExpr(ExprFromProto(exp[idx]).get())->s(); + auto e = DimExprToDExpr(ExprFromProto(exp[idx]).get()).simplify(); if (e->is_dynamic()) { std::optional solved_value = e->solve(shp.dim_size(idx)); @@ -595,17 +594,17 @@ absl::Status CompileToLocalExecutable( } } - std::vector dyn_exprs; + std::vector dyn_exprs; for (int d : shp.dim_sizes()) { - dyn_exprs.push_back(xla::DynExpr::_(d)); + dyn_exprs.push_back(xla::DExpr::Const(d)); } for (int j = 0; j < exp.size(); ++j) { - auto e = DimExprToDynExpr(ExprFromProto(exp[j]).get())->s(); + auto e = DimExprToDExpr(ExprFromProto(exp[j]).get()).simplify(); if (e->is_dynamic()) { dyn_exprs[j] = e; } } - shp.set_expressions(dyn_exprs); + shp.set_expressions(std::move(dyn_exprs)); } } } @@ -691,8 +690,8 @@ absl::Status CompileToLocalExecutable( if (e->is_dynamic()) { int64_t old = shp.dim_size(j); old_vars.push_back({i, j, old}); - xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch); - xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s(); + xla::DExpr padded_expr = xla::DExpr::Const(filled_batch); + xla::DExpr subst_expr = e.substitute(1, padded_expr).simplify(); int64_t new_dim = subst_expr->get_val(); if (new_dim >= 0) { shp.set_dim(j, new_dim); @@ -1266,7 +1265,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { if (!xla_shape.IsArray() || xla_shape.expressions().empty()) continue; for (int dim = 0; dim < xla_shape.expressions().size(); dim++) { - xla::DynExpr* expr = xla_shape.expressions(dim); + const auto& expr = xla_shape.expressions(dim); if (expr && expr->is_dynamic()) { int input_idx = comp_result->input_mapping[i] - num_constant_args; if (input_idx < 0 || input_idx >= ctx->num_inputs()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index e8f141b9ef86b7..54d0b51c78bc44 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -753,34 +753,34 
@@ std::unique_ptr ExprFromProto(const ExpressionProto& proto) { } } -static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { +static xla::DExpr DimExprToDExpr(const DimExpr* e) { switch (e->kind()) { case DimExpr::Kind::kConstant: { auto* ac = static_cast(e); - return xla::DynExpr::_(ac->value()); + return xla::DExpr::Const(ac->value()); } case DimExpr::Kind::kVariable: { auto* av = static_cast(e); - return xla::DynExpr::V(av->id()); // Use 1 all the time for now + return xla::DExpr::Var(av->id()); // Preserve the variable id assigned by shape inference } case DimExpr::Kind::kAdd: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) + DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kSub: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) - DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kMul: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) * DimExprToDExpr(ee->rhs()); } case DimExpr::Kind::kDiv: { auto* ee = static_cast(e); - return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + return DimExprToDExpr(ee->lhs()) / DimExprToDExpr(ee->rhs()); } } - return nullptr; + return xla::DExpr(); } // Runs Grappler static inference and logs any ExpressionProto found in output @@ -1879,7 +1879,7 @@ absl::Status MarkForCompilationPassImpl::AssignDimVars(void) { } for (auto& pDim: (it->second)[output_index]) { DimExpr * d= pDim.get(); - xla::DynExpr * dyn = DimExprToDynExpr(d); + xla::DExpr dyn = DimExprToDExpr(d); auto new_ids = dyn->get_all_ids(); for (auto id : new_ids) { cluster->add_dim_var(id); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index dabbceb65076d4..2dd92e43f6976a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc
@@ -51,7 +51,7 @@ absl::Status ShapeHandleToTensorShape( std::vector dims(context->Rank(handle)); MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - std::vector dyn_exprs; + std::vector dyn_exprs; if (flags->tf_xla_enable_dynamic_sizes) { dyn_exprs.resize(context->Rank(handle)); } @@ -59,14 +59,14 @@ absl::Status ShapeHandleToTensorShape( dims[i] = context->Value(context->Dim(handle, i)); if (flags->tf_xla_enable_dynamic_sizes) { auto ratio = context->DynamicRatio(context->Dim(handle, i)); - dyn_exprs[i] = ratio > 0 ? (ratio * *xla::DynExpr::V(1))->s() - : xla::DynExpr::_(dims[i]); + dyn_exprs[i] = ratio > 0 ? xla::DExpr::Const(ratio) * xla::DExpr::Var(1) + : xla::DExpr::Const(dims[i]); } } auto status = PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); if (flags->tf_xla_enable_dynamic_sizes) { - shape->set_expressions(dyn_exprs); + shape->set_expressions(std::move(dyn_exprs)); } return status; } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 06196e8d78f26c..ccf97cb24d75f8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -441,13 +441,13 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( bool has_dynamic = false; for (int dim = 0; dim < subshape.expressions().size(); ++dim) { - auto expr = subshape.expressions(dim); - if (expr != nullptr && expr->is_dynamic()) { + const auto& expr = subshape.expressions(dim); + if (expr && expr->is_dynamic()) { has_dynamic = true; VLOG(1) << "Current expression is " << expr; if (run_options) { - xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size()); - xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + xla::DExpr batch_size = xla::DExpr::Const(run_options->batch_size()); + xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify(); shape.set_dim(dim, subst_expr->get_val()); } else { // TODO: Fallback to BatchSizeResource for 
now. Remove it later. @@ -456,9 +456,9 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( ScopedStepContainer* step_container = ctx->step_container(); TF_RETURN_IF_ERROR(step_container->Lookup( ctx->resource_manager(), BatchSizeResourceName, &bsr)); - xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); + xla::DExpr batch_size = xla::DExpr::Const(bsr->GetBatchSize()); // Just substitute Var(1) for now. - xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify(); shape.set_dim(dim, subst_expr->get_val()); bsr->Unref(); } diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index 4f31c79f91a719..1038a73fc897b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -116,14 +116,14 @@ class DenseBincountOp : public XlaOpKernel { auto i_shape = xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()}); auto i = xla::Iota(ctx->builder(), i_shape, 0); + xla::DExpr flattened_expr = + (input_shape.expressions(0) * input_shape.expressions(1)).simplify(); i = xla::Reshape( i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, - {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), - xla::DynExpr::one}); + {flattened_expr, xla::DExpr::Const(1)}); auto j = xla::Reshape( input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, - {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), - xla::DynExpr::one}); + {flattened_expr, xla::DExpr::Const(1)}); std::vector iotas_to_concat; iotas_to_concat.push_back(i); iotas_to_concat.push_back(j); @@ -135,7 +135,7 @@ class DenseBincountOp : public XlaOpKernel { if (has_weights && !binary_output_) { weights = xla::Reshape( weights, {input_shape.dimensions(0) * input_shape.dimensions(1)}, - {(*input_shape.expressions(0) * *input_shape.expressions(1))->s()}); + 
{flattened_expr}); updates = weights; } } else { diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 466707f0d777e2..5fa0d823a72e69 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -99,7 +99,7 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput( new_shape.set_dimensions(num_dims - 1, num_groups); new_shape.add_dimensions( filter_shape.dimensions(num_dims - 1) / num_groups, - (*filter_shape.expressions(num_dims - 1) / num_groups)->s()); + (filter_shape.expressions(num_dims - 1) / num_groups).simplify()); xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions(), new_shape.expressions()); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index 8988bd44528f00..c5a11a0fc5df5f 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -43,23 +43,22 @@ limitations under the License. namespace tensorflow { namespace { -std::vector GetFilledExpressions(const xla::Shape& shape) { - std::vector expressions; +std::vector GetFilledExpressions(const xla::Shape& shape) { + std::vector expressions; expressions.reserve(shape.dimensions_size()); auto shape_expressions = shape.expressions(); for (int64_t i = 0; i < shape.dimensions_size(); ++i) { - expressions.push_back( - i < shape_expressions.size() && shape_expressions[i] != nullptr - ? shape_expressions[i] - : xla::DynExpr::_(shape.dimensions(i))); + expressions.push_back(i < shape_expressions.size() && shape_expressions[i] + ? 
shape_expressions[i] + : xla::DExpr::Const(shape.dimensions(i))); } return expressions; } -xla::DynExpr* CollapseExpressions(absl::Span expressions) { - xla::DynExpr* collapsed = xla::DynExpr::one; - for (xla::DynExpr* expression : expressions) { - collapsed = (*collapsed * *expression)->s(); +xla::DExpr CollapseExpressions(absl::Span expressions) { + xla::DExpr collapsed = xla::DExpr::Const(1); + for (const xla::DExpr& expression : expressions) { + collapsed = (collapsed * expression).simplify(); } return collapsed; } @@ -88,19 +87,7 @@ class DynamicPartitionOp : public XlaOpKernel { xla::XlaOp partitions_1d, const xla::Shape& data_1d_shape, const xla::Shape& partition_1d_shape) { int64_t input_count = data_1d_shape.dimensions(0); - VLOG(1) << "data_1d_shape=" - << xla::ShapeUtil::HumanString(data_1d_shape) - << " partition_1d_shape=" - << xla::ShapeUtil::HumanString(partition_1d_shape) - << " input_count=" << input_count - << " num_partitions=" << num_partitions_; - // TODO: Use the smaller runtime size to mask off padded tail elements when one - // input loses dynamic shape information and falls back to its static bound. - // For valid DynamicPartition inputs the true prefix sizes should match, so - // this keeps us from treating padding on either side as real elements. 
- xla::XlaOp dynamic_input_count = xla::Min( - xla::GetDimensionSize(data_1d, 0), - xla::GetDimensionSize(partitions_1d, 0)); + xla::XlaOp dynamic_input_count = xla::GetDimensionSize(data_1d, 0); xla::XlaOp input_index = xla::Iota(ctx->builder(), xla::S32, input_count); xla::XlaOp valid_element = xla::Lt(input_index, dynamic_input_count); xla::XlaOp invalid_partition = @@ -212,16 +199,18 @@ class DynamicPartitionOp : public XlaOpKernel { auto partitions_1d = xla::Reshape( partitions, {input_count}, {CollapseExpressions(flattened_partition_exprs)}); - xla::Shape data_1d_shape = - xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count}); + xla::Shape data_1d_shape = xla::ShapeUtil::MakeShape( + data_shape.element_type(), {input_count}, + {xla::DExpr::Const(input_count)}); xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape( - partition_shape.element_type(), {input_count}); + partition_shape.element_type(), {input_count}, + {xla::DExpr::Const(input_count)}); std::vector output, partition_length; std::tie(output, partition_length) = DynamicPartition1D( ctx, data_1d, partitions_1d, data_1d_shape, partitions_1d_shape); - std::vector output_exprs; + std::vector output_exprs; output_exprs.reserve(data_shape.dimensions().size() - partition_shape.dimensions().size() + 1); output_exprs.push_back(CollapseExpressions(absl::MakeConstSpan( @@ -232,8 +221,9 @@ class DynamicPartitionOp : public XlaOpKernel { output_exprs.push_back(data_exprs[i]); } for (int64_t i = 0; i < num_partitions_; ++i) { - auto reshape = - xla::Reshape(output[i], output_shape_bound_dims, output_exprs); + xla::Shape output_shape = xla::ShapeUtil::MakeShape( + data_shape.element_type(), output_shape_bound_dims, output_exprs); + auto reshape = xla::Reshape(output_shape, output[i]); if (partitions_are_static) { int64_t size = absl::c_count(partitions_static, i); ctx->SetOutput(i, xla::SliceInDim(reshape, 0, size, 1, 0)); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc 
b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 1ecaabe0f9a046..305b527cc76632 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -175,7 +175,7 @@ class DynamicStitchOp : public XlaOpKernel { // first reshaped dimension is the number of indices for this input. new_shape.AddDim(indices[input_num].shape().dimensions(0)); new_shape.AddExpression( - xla::DynExpr::_(indices[input_num].shape().dimensions(0))); + xla::DExpr::Const(indices[input_num].shape().dimensions(0))); // Then the rest are the common extra shape. for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { new_shape.AddDim(data0_shape.dim_size(d)); diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 0b2d783d6aa2a6..e7bc0890317316 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -270,12 +270,10 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public XlaOpKernel { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); - absl::Span input_expressions = - input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, {input_shape.dimensions_size() - 1}, - input_expressions); + input_shape.expressions()); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -328,9 +326,6 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); - absl::Span input_expressions = - input_shape.expressions(); - std::vector reduce_axes; for (int64_t i = 0; i + 1 < input_shape.dimensions_size(); ++i) { reduce_axes.push_back(i); @@ -339,7 
+334,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, {input_shape.dimensions_size() - 1}, - input_expressions); + input_shape.expressions()); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index ad67310aad9955..8ef6dfbd9a1702 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -328,9 +328,9 @@ class MatrixDiagOp : public XlaOpKernel { TensorShape output_shape = diag_shape; output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2); output_shape.AddDim(num_rows); - output_shape.AddExpression(xla::DynExpr::_(num_rows)); + output_shape.AddExpression(xla::DExpr::Const(num_rows)); output_shape.AddDim(num_cols); - output_shape.AddExpression(xla::DynExpr::_(num_cols)); + output_shape.AddExpression(xla::DExpr::Const(num_cols)); xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes(), output_shape.get_filled_expressions()); xla::XlaOp diag = context->Input(0); @@ -410,13 +410,13 @@ class MatrixDiagPartOp : public XlaOpKernel { const int num_diags = upper_diag_index - lower_diag_index + 1; if (num_diags > 1) { output_shape.AddDim(num_diags); - output_shape.AddExpression(xla::DynExpr::_(num_diags)); + output_shape.AddExpression(xla::DExpr::Const(num_diags)); } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); output_shape.AddDim(max_diag_len); - output_shape.AddExpression(xla::DynExpr::_(max_diag_len)); + output_shape.AddExpression(xla::DExpr::Const(max_diag_len)); // Computes output. 
xla::XlaOp input = context->Input(0); @@ -530,13 +530,13 @@ class MatrixSetDiagOp : public XlaOpKernel { expected_diag_shape.RemoveLastDims(2); if (num_diags > 1) { expected_diag_shape.AddDim(num_diags); - expected_diag_shape.AddExpression(xla::DynExpr::_(num_diags)); + expected_diag_shape.AddExpression(xla::DExpr::Const(num_diags)); } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); expected_diag_shape.AddDim(max_diag_len); - expected_diag_shape.AddExpression(xla::DynExpr::_(max_diag_len)); + expected_diag_shape.AddExpression(xla::DExpr::Const(max_diag_len)); OP_REQUIRES( context, expected_diag_shape == diag_shape, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 8f913bcd3c7d83..207bc6f7ce8984 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -96,11 +96,12 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape()); lhs_broadcast_shape.AddDim(m); - lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); + lhs_broadcast_shape.AddExpression(xla::DExpr::Const(m)); lhs_broadcast_shape.AddDim(m); - lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); - auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes(), - lhs_broadcast_shape.get_filled_expressions()); + lhs_broadcast_shape.AddExpression(xla::DExpr::Const(m)); + auto lhs_output = + BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes(), + lhs_broadcast_shape.get_filled_expressions()); if (!lhs_output.ok()) { xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); return {error, error}; @@ -108,11 +109,12 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, 
TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape()); rhs_broadcast_shape.AddDim(m); - rhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); + rhs_broadcast_shape.AddExpression(xla::DExpr::Const(m)); rhs_broadcast_shape.AddDim(n); - rhs_broadcast_shape.AddExpression(xla::DynExpr::_(n)); - auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes(), - rhs_broadcast_shape.get_filled_expressions()); + rhs_broadcast_shape.AddExpression(xla::DExpr::Const(n)); + auto rhs_output = + BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes(), + rhs_broadcast_shape.get_filled_expressions()); if (!rhs_output.ok()) { xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); return {error, error}; diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index ea239f78e32d4f..399d26b6f55de0 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -64,10 +64,9 @@ class PackOp : public XlaOpKernel { std::vector reshaped_inputs(num); TensorShape child_shape(shapes[0]); - std::vector exprs = shapes[0].get_filled_expressions(); + std::vector exprs = child_shape.get_filled_expressions(); child_shape.InsertDim(axis, 1); - exprs.insert(exprs.begin() + axis, xla::DynExpr::one); - + exprs.insert(exprs.begin() + axis, xla::DExpr::Const(1)); for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. 
reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes(), diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index b3cdb7bac0c7dc..b57d546fcffb0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -264,11 +264,11 @@ class MaxPoolOp : public PoolingOp { absl::InlinedVector new_dims(result_shape->dimensions().begin(), result_shape->dimensions().end()); - absl::InlinedVector new_exprs( + absl::InlinedVector new_exprs( result_shape->expressions().begin(), result_shape->expressions().end()); new_dims[1] /= *vect_width; - new_exprs[1] = *new_exprs[1] / *vect_width; + new_exprs[1] = (new_exprs[1] / xla::DExpr::Const(*vect_width)).simplify(); new_dims.insert(new_dims.begin() + 2, *vect_width); pooling = xla::Transpose(xla::Reshape(pooling, new_dims, new_exprs), {0, 1, 3, 4, 2}); diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index ba6860caa9cb1d..6b7e990f7030fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -173,11 +173,9 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { if (!xla::ShapeUtil::IsScalar(axis_shape)) { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); - absl::Span input_expressions = - input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, {axis_}, - input_expressions); + input_shape.expressions()); }; min_range = convert_to_input_shape(min_range); max_range = convert_to_input_shape(max_range); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index de50031215b9fb..6a5f40441beff3 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -106,7 +106,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { } std::vector final_shape; - std::vector final_exprs; + std::vector final_exprs; for (int i = 0; i < data_shape.dims(); ++i) { if (!bitmap[i]) { // If we are not reducing along dimension i. @@ -118,7 +118,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { // same number of dimensions, so we set the dimension of i to // '1'. final_shape.push_back(1); - final_exprs.push_back(xla::DynExpr::one); + final_exprs.push_back(xla::DExpr::Const(1)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 9e142f9f9aff6a..c019f42927bb9b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -57,7 +57,7 @@ class ReshapeOp : public XlaOpKernel { // is one. TensorShape shape; int64_t product = 1; - xla::DynExpr* product_expr = xla::DynExpr::one; + xla::DExpr product_expr = xla::DExpr::Const(1); int unknown_index = -1; bool shape_has_zero_dim = false; int ratio = 1; @@ -70,41 +70,39 @@ class ReshapeOp : public XlaOpKernel { unknown_index, " and ", d)); unknown_index = d; shape.AddDim(1); - shape.AddExpression(xla::DynExpr::one); + shape.AddExpression(xla::DExpr::Const(1)); ratio = 1; } else if (size == 0) { // We don't include zero-sized dimension in product, so that we can // still calculate number of elements for non-zero-sized dimensions and // therefore infer their shapes. 
shape.AddDim(size); - shape.AddExpression(xla::DynExpr::_(size)); + shape.AddExpression(xla::DExpr::Const(size)); shape_has_zero_dim = true; } else { - xla::DynExpr* size_expr; + xla::DExpr size_expr; OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( "size ", d, " must be non-negative, not ", size)); shape.AddDim(size); - xla::DynExpr* input_expr = - d < input_shape.dims() ? input_shape.get_filled_expression(d) : nullptr; - if (input_expr != nullptr && input_expr->is_dynamic()) { + if (d < input_shape.dims() && input_shape.get_filled_expression(d) && + input_shape.get_filled_expression(d)->is_dynamic()) { int old = input_shape.dim_size(d); bool is_split = (old > size); int local_ratio = ratio * (is_split ? old / size : size / old); - xla::DynExpr* new_expr = - (size > old) - ? *input_expr * - *xla::DynExpr::_(local_ratio) // Split [xy] -> [x/y,y] - : *input_expr / - *xla::DynExpr::_(local_ratio); // Reduce [x,y] -> [x*y] + xla::DExpr input_dexpr = input_shape.get_filled_expression(d); + xla::DExpr ratio_expr = xla::DExpr::Const(local_ratio); + xla::DExpr new_expr = + (size > old) ? input_dexpr * ratio_expr // Reduce [x,y] -> [x*y] + : input_dexpr / ratio_expr; // Split [xy] -> [x/y,y] // Pass ratio to next dimension if this is a split, otherwise just // reset it to 1. ratio = is_split ? local_ratio : 1; - size_expr = new_expr->s(); + size_expr = new_expr.simplify(); } else { - size_expr = xla::DynExpr::_(size); + size_expr = xla::DExpr::Const(size); if (ratio != 1) { // A split dynamic dimension can be materialized by multiple later // known dimensions. 
Any unresolved remainder is kept in `ratio` @@ -117,15 +115,15 @@ class ReshapeOp : public XlaOpKernel { } } } - shape.AddExpression(size_expr); product *= size; - product_expr = (*product_expr * *size_expr); + product_expr = product_expr * size_expr; + shape.AddExpression(size_expr); } } auto input = ctx->Input(0); if (unknown_index != -1) { int64_t input_num_elements = 1; - xla::DynExpr* input_num_elements_expr = xla::DynExpr::one; + xla::DExpr input_num_elements_expr = xla::DExpr::Const(1); bool input_has_zero_dim = false; for (int dim = 0; dim < input_shape.dims(); dim++) { // For zero dimension, we don't count it into `input_num_elements` @@ -134,16 +132,17 @@ class ReshapeOp : public XlaOpKernel { if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { input_num_elements *= input_shape.dim_size(dim); input_num_elements_expr = - (*input_num_elements_expr * *input_shape.get_filled_expression(dim))->s(); + (input_num_elements_expr * input_shape.get_filled_expression(dim)) + .simplify(); } else { input_has_zero_dim = true; } } int64_t missing = input_num_elements / product; - input_num_elements_expr = input_num_elements_expr->s(); - product_expr = product_expr->s(); - auto missing_expr = *input_num_elements_expr / *product_expr; + input_num_elements_expr = input_num_elements_expr.simplify(); + product_expr = product_expr.simplify(); + auto missing_expr = input_num_elements_expr / product_expr; if (!input_has_zero_dim) { if (input_xla_shape->is_static() || input_xla_shape->dimensions().size() != 1) { @@ -168,14 +167,14 @@ class ReshapeOp : public XlaOpKernel { // This expression only approximates the padded size: the true value // uses ceil(input_num_elements / product) * product, which we do not // model symbolically here. 
- xla::DynExpr* padded_input_num_expr = - (*(*input_num_elements_expr / *product_expr) * *product_expr)->s(); + xla::DExpr padded_input_num_expr = + ((input_num_elements_expr / product_expr) * product_expr) + .simplify(); input_shape.set_expression(0, padded_input_num_expr); } } shape.set_dim(unknown_index, missing); - shape.set_expression( - unknown_index, missing_expr->s()); + shape.set_expression(unknown_index, missing_expr.simplify()); } OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), @@ -195,14 +194,14 @@ class ReshapeOp : public XlaOpKernel { std::vector output_dim_sizes; std::vector dims_are_dynamic; - std::vector output_dim_exprs; + std::vector output_dim_exprs; const auto& dims = shape.dims(); dims_are_dynamic.reserve(dims); output_dim_sizes.reserve(dims); for (int64_t i = 0; i < dims; ++i) { output_dim_sizes.push_back( xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); - output_dim_exprs.push_back(xla::DynExpr::_(-111)); + output_dim_exprs.push_back(xla::DExpr::Unknown(111)); } OP_REQUIRES_OK( ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); @@ -225,26 +224,27 @@ class ReshapeOp : public XlaOpKernel { // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group // containing -1 will be 6. xla::XlaOp product = xla::One(ctx->builder(), xla::S32); - xla::DynExpr* product_expr = xla::DynExpr::one; + xla::DExpr product_expr = xla::DExpr::Const(1); for (int64_t dim = start.first; dim < end.first; ++dim) { if (input_xla_shape->is_dynamic_dimension(dim)) { input_is_dynamic = true; } product = xla::Mul(product, xla::GetDimensionSize(input, dim)); - product_expr = (*product_expr * *input_shape.get_filled_expression(dim))->s(); + product_expr = + (product_expr * input_shape.get_filled_expression(dim)).simplify(); } bool unknown_dim_in_group = false; // The real size for the -1 dimension in a reshape. E.g., in // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. 
xla::XlaOp unknown_dim_size = product; - xla::DynExpr* unknown_dim_expr = product_expr; + xla::DExpr unknown_dim_expr = product_expr; for (int64_t dim = start.second; dim < end.second; ++dim) { if (dim == unknown_index) { unknown_dim_in_group = true; } else { unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); unknown_dim_expr = - (*unknown_dim_expr / *output_dim_exprs[dim])->s(); + (unknown_dim_expr / output_dim_exprs[dim]).simplify(); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 8e6d6c09967319..cb6f8cebf0a8d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -89,14 +89,14 @@ class ReverseSequenceOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, {input_shape.get_filled_expression(batch_dim_), input_shape.get_filled_expression(seq_dim_), - xla::DynExpr::one}), + xla::DExpr::Const(1)}), /*iota_dimension=*/0); xla::XlaOp forward_idx = xla::Iota( builder, xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, {input_shape.get_filled_expression(batch_dim_), input_shape.get_filled_expression(seq_dim_), - xla::DynExpr::one}), + xla::DExpr::Const(1)}), /*iota_dimension=*/1); xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0}); reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)), diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index fb1df805d26b68..510da9023f6348 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -300,9 +300,10 @@ class ExpandDimsOp : public XlaOpKernel { } const int existing_exprs_size = static_cast(existing_exprs.size()); - std::vector new_exprs(existing_exprs_size); - for (size_t i = 0; i < new_exprs.size(); ++i) { - new_exprs[i] = 
existing_exprs[i]; + std::vector new_exprs; + new_exprs.reserve(existing_exprs_size); + for (size_t i = 0; i < existing_exprs.size(); ++i) { + new_exprs.push_back(existing_exprs[i]); } // We emulate numpy's interpretation of the dim axis when @@ -314,8 +315,7 @@ class ExpandDimsOp : public XlaOpKernel { // Clamp to the end if needed. dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - new_exprs.emplace(new_exprs.begin() + dim, xla::DynExpr::one); - + new_exprs.emplace(new_exprs.begin() + dim, xla::DExpr::Const(1)); ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape, new_exprs)); } }; @@ -340,7 +340,7 @@ class SqueezeOp : public XlaOpKernel { absl::flat_hash_set wrapped_squeeze_dims; wrapped_squeeze_dims.reserve(squeeze_dims_.size()); std::vector new_shape; - std::vector new_exprs; + std::vector new_exprs; // Validate squeeze dims against the input. for (int32_t dim : squeeze_dims_) { OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 67ed6bfe09ce59..57f48987e8a6b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -66,17 +66,17 @@ class SliceOp : public XlaOpKernel { ctx->ConstantInputAsIntVector(2, &size).ok(); if (all_begins_are_constant && all_sizes_are_constant) { std::vector wrapped_size(size.size()); - std::vector wrapped_size_exprs(size.size()); + std::vector wrapped_size_exprs(size.size()); // `begin` is a compile-time constant. for (int i = 0; i < input_dims; ++i) { if (size[i] == -1) { // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". 
wrapped_size[i] = input_shape.dim_size(i) - begin[i]; wrapped_size_exprs[i] = - (*input_shape.get_filled_expression(i) - begin[i])->s(); + (input_shape.get_filled_expression(i) - begin[i]).simplify(); } else { wrapped_size[i] = size[i]; - wrapped_size_exprs[i] = xla::DynExpr::_(size[i]); + wrapped_size_exprs[i] = xla::DExpr::Const(size[i]); } } @@ -101,17 +101,17 @@ class SliceOp : public XlaOpKernel { } } - std::vector begin_exprs; + std::vector begin_exprs; for (int d : begin){ - begin_exprs.push_back(xla::DynExpr::_(d)); + begin_exprs.push_back(xla::DExpr::Const(d)); } std::vector limits; - std::vector exprs; + std::vector exprs; limits.reserve(begin.size()); exprs.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + wrapped_size[i]); - exprs.push_back((*begin_exprs[i] + *wrapped_size_exprs[i])->s()); + exprs.push_back((begin_exprs[i] + wrapped_size_exprs[i]).simplify()); } std::vector strides(begin.size(), 1); auto slice = @@ -127,8 +127,8 @@ class SliceOp : public XlaOpKernel { // If there is a dynamic dimension, properly set dimension size of // the slice. 
auto dynamic_size = xla::Reshape( - xla::Slice(ctx->Input(2), {i}, {i + 1}, {xla::DynExpr::_(i)}, - {xla::DynExpr::_(i + 1)}, {1}), + xla::Slice(ctx->Input(2), {i}, {i + 1}, {xla::DExpr::Const(i)}, + {xla::DExpr::Const(i + 1)}, {1}), {}); slice = xla::SetDimensionSize(slice, dynamic_size, i); @@ -167,13 +167,14 @@ class SliceOp : public XlaOpKernel { } if (all_sizes_are_constant && !constant_size_is_minus_one) { xla::XlaOp input = ctx->Input(0); - std::vector output_exprs; + std::vector output_exprs; output_exprs.reserve(size.size()); for (int64_t d : size) { - output_exprs.push_back(xla::DynExpr::_(d)); + output_exprs.push_back(xla::DExpr::Const(d)); } - ctx->SetOutput( - 0, xla::DynamicSlice(input, begin_indices, size, output_exprs)); + ctx->SetOutput(0, + xla::DynamicSlice(input, begin_indices, size, + output_exprs)); } else { // Size is not constant, use input size as upperbound and then set // dimension size on it. diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index e23eda1c506a7d..54bc07a313d4c6 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -81,21 +81,25 @@ class SplitOp : public XlaOpKernel { // All the slices are the same size: this is the size along the // split dimension. const int32_t slice_size = input_shape.dim_size(split_dim) / num_split; - auto slice_expr = *input_shape.get_filled_expression(split_dim) / num_split; + xla::DExpr slice_expr = + input_shape.get_filled_expression(split_dim) / num_split; // The vectors we will use to define the slice. The entry for the // split dimensions varies for each output. 
std::vector begin(input_shape.dims(), 0); std::vector limits(input_shape.dims()); - std::vector begin_expr(input_shape.dims(), xla::DynExpr::zero); - std::vector limits_expr(input_shape.dims()); + std::vector begin_expr; + std::vector limits_expr; + begin_expr.reserve(input_shape.dims()); + limits_expr.reserve(input_shape.dims()); std::vector strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { // Initially set up the limits to be the full size of the input: // the split dimension is filled in below. int64_t dim = input_shape.dim_size(i); limits[i] = dim; - limits_expr[i] = input_shape.get_filled_expression(i); + begin_expr.push_back(xla::DExpr::Const(0)); + limits_expr.push_back(input_shape.get_filled_expression(i)); } // Create each of the outputs. @@ -104,9 +108,8 @@ class SplitOp : public XlaOpKernel { begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - begin_expr[split_dim] = i * *slice_expr; - limits_expr[split_dim] = (*xla::DynExpr::_(i + 1) * *slice_expr)->s(); - + begin_expr[split_dim] = i * slice_expr; + limits_expr[split_dim] = ((i + 1) * slice_expr).simplify(); ctx->SetOutput(i, xla::Slice(input, begin, limits, begin_expr, limits_expr, strides)); } @@ -217,20 +220,23 @@ class SplitVOp : public XlaOpKernel { auto dim_sizes = input_shape.dim_sizes(); std::vector limits(dim_sizes.begin(), dim_sizes.end()); std::vector strides(input_shape.dims(), 1); - std::vector begin_expr(input_shape.dims(), - xla::DynExpr::zero); + std::vector begin_expr(input_shape.dims(), xla::DExpr::Const(0)); auto input_exprs = input_shape.get_filled_expressions(); - std::vector limits_expr(input_exprs.begin(), - input_exprs.end()); + std::vector limits_expr; + limits_expr.reserve(input_exprs.size()); + for (auto expr : input_exprs) { + limits_expr.push_back(expr); + } for (int i = 0; i < num_split; ++i) { int slice_size = split_sizes[i]; - xla::DynExpr* slice_expr = xla::DynExpr::_(slice_size); + xla::DExpr slice_expr = 
xla::DExpr::Const(slice_size); // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - limits_expr[split_dim] = (*begin_expr[split_dim] + *slice_expr)->s(); - ctx->SetOutput( - i, xla::Slice(input, begin, limits, begin_expr, limits_expr, strides)); + limits_expr[split_dim] = (begin_expr[split_dim] + slice_expr).simplify(); + ctx->SetOutput(i, + xla::Slice(input, begin, limits, begin_expr, limits_expr, + strides)); begin[split_dim] = limits[split_dim]; begin_expr[split_dim] = limits_expr[split_dim]; } diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index d23723f9e7fe90..9e3d5b6fa6aad5 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -51,7 +51,7 @@ xla::BitGeneratorTy GetBitGeneratorForDevice( std::tie(state, key) = xla::ScramblePhiloxKey(key); xla::XlaOp philox_state = xla::ConcatInDim( key.builder(), - {xla::Reshape(key, {1}, {xla::DynExpr::one}), state}, 0); + {xla::Reshape(key, {1}, {xla::DExpr::Const(1)}), state}, 0); xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, philox_state, shape); return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), @@ -421,7 +421,6 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); - auto bcasted_means = BroadcastTo(ctx->Input(2), shape.dim_sizes(), shape.get_filled_expressions()); OP_REQUIRES_OK(ctx, bcasted_means.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 7b328772790c63..bbe7f63a80a823 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ 
-162,10 +162,10 @@ class StridedSliceOp : public XlaOpKernel { auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); xla::XlaOp begin_index, end_index; int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; - xla::DynExpr* input_expr = input_xla_shape.expressions(i); + const xla::DExpr& input_expr = input_xla_shape.expressions(i); bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i) || - (input_expr != nullptr && input_expr->is_dynamic()); + input_expr->is_dynamic(); xla::XlaOp dim_size; if (xla_input_is_dynamic) { dim_size = xla::GetDimensionSize(ctx->Input(0), i); @@ -256,8 +256,8 @@ class StridedSliceOp : public XlaOpKernel { absl::InlinedVector begin; absl::InlinedVector end; - absl::InlinedVector begin_expr; - absl::InlinedVector end_expr; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -307,7 +307,7 @@ class StridedSliceOp : public XlaOpKernel { ", output shape must be a compile-time constant")); absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin, slice_end, slice_strides; - absl::InlinedVector slice_begin_expr, slice_end_expr; + absl::InlinedVector slice_begin_expr, slice_end_expr; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); @@ -319,16 +319,18 @@ class StridedSliceOp : public XlaOpKernel { } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. 
- auto input_exprs = input_shape.get_filled_expressions(); + xla::DExpr input_expr = input_shape.get_filled_expression(i); slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); slice_begin_expr.push_back( - (*input_exprs[i] - *begin_expr[i] - 1)->s()); + (input_expr - begin_expr[i] - xla::DExpr::Const(1)).simplify()); slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, input_shape.dim_size(i) - begin[i] - 1)); slice_end_expr.push_back( (end[i] < begin[i]) - ? (*input_exprs[i] - *end_expr[i] - 1)->s() - : (*input_exprs[i] - *begin_expr[i] - 1)->s()); + ? (input_expr - end_expr[i] - xla::DExpr::Const(1)) + .simplify() + : (input_expr - begin_expr[i] - xla::DExpr::Const(1)) + .simplify()); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } @@ -462,8 +464,8 @@ class StridedSliceGradOp : public XlaOpKernel { PartialTensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; - absl::InlinedVector begin_expr; - absl::InlinedVector end_expr; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; StridedSliceShapeSpec shape_spec; OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, @@ -487,17 +489,22 @@ class StridedSliceGradOp : public XlaOpKernel { VLOG(1) << "xla final_shape" << final_shape; VLOG(1) << "input_shape" << input_shape.DebugString(); auto input_sizes = input_shape.dim_sizes(); - auto input_exprs = input_shape.get_filled_expressions(); + std::vector input_exprs; + input_exprs.reserve(input_shape.dims()); + for (int64_t i = 0; i < input_shape.dims(); ++i) { + input_exprs.push_back(input_shape.get_filled_expression(i)); + } // For unknown output dim the bound of the output shape is input. Pad and // double the size of input shape to leave enough buffer to avoid OOB // dynamic update slice. 
auto input_sizes_padded = input_shape.dim_sizes(); - auto input_exprs_padded = input_shape.get_filled_expressions(); + auto input_exprs_padded = input_exprs; bool need_padding = false; for (int64_t i = 0; i < processing_shape.dims(); ++i) { if (processing_shape.dim_size(i) == -1) { input_sizes_padded[i] *= 2; - input_exprs_padded[i] = (2 * *input_exprs_padded[i])->s(); + input_exprs_padded[i] = + (xla::DExpr::Const(2) * input_exprs_padded[i]).simplify(); need_padding = true; } } @@ -547,8 +554,8 @@ class StridedSliceGradOp : public XlaOpKernel { // padding in the final result. std::vector strides(input_shape.dims(), 1); std::vector start_indices(input_shape.dims(), 0); - std::vector start_exprs(input_shape.dims(), - xla::DynExpr::zero); + std::vector start_exprs(input_shape.dims(), + xla::DExpr::Const(0)); grad = xla::Slice(grad, start_indices, input_sizes, start_exprs, input_exprs, strides); } @@ -558,8 +565,8 @@ class StridedSliceGradOp : public XlaOpKernel { TensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; - absl::InlinedVector begin_expr; - absl::InlinedVector end_expr; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; TensorShape input_shape; @@ -702,8 +709,8 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape final_shape; absl::InlinedVector begin; absl::InlinedVector end; - absl::InlinedVector begin_expr; - absl::InlinedVector end_expr; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -760,7 +767,7 @@ class StridedSliceAssignOp : public XlaOpKernel { absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin; absl::InlinedVector slice_dims; - absl::InlinedVector slice_exprs; + absl::InlinedVector slice_exprs; for (int i = 0; i < begin.size(); ++i) { // TODO(b/121179231): implement strides != 1 OP_REQUIRES( @@ -770,14 +777,14 @@ class 
StridedSliceAssignOp : public XlaOpKernel { slice_begin.push_back( xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); - slice_exprs.push_back(xla::DynExpr::_(end[i] - begin[i])); + slice_exprs.push_back(xla::DExpr::Const(end[i] - begin[i])); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. slice_begin.push_back( xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); - slice_exprs.push_back(xla::DynExpr::_(begin[i] - end[i])); + slice_exprs.push_back(xla::DExpr::Const(begin[i] - end[i])); dimensions_to_reverse.push_back(i); } } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 25cd54b5fbc5a2..9bc5d0d95bfd8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -165,7 +165,7 @@ class TensorArrayOp : public XlaOpKernel { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); - ta_shape.AddExpression(xla::DynExpr::_(size)); + ta_shape.AddExpression(xla::DExpr::Const(size)); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); value = xla::Broadcast(zero, ta_shape.dim_sizes(), @@ -279,7 +279,7 @@ class TensorArrayReadOp : public XlaOpKernel { auto slice_shape = ta_shape.dim_sizes(); auto slice_exprs = ta_shape.get_filled_expressions(); slice_shape[0] = 1LL; - slice_exprs[0] = xla::DynExpr::_(1LL); + slice_exprs[0] = xla::DExpr::Const(1LL); xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape, slice_exprs); @@ -477,9 +477,13 @@ class TensorArrayConcatOp : public XlaOpKernel { auto ta_dims = ta_shape.dim_sizes(); auto ta_exprs = ta_shape.get_filled_expressions(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); - std::vector exprs(ta_exprs.begin() + 1, ta_exprs.end()); + std::vector exprs; + 
exprs.reserve(ta_exprs.size() - 1); + for (auto it = ta_exprs.begin() + 1; it != ta_exprs.end(); ++it) { + exprs.push_back(*it); + } shape[0] *= ta_shape.dim_size(0); - exprs[0] = *ta_exprs[0] * *ta_shape.get_filled_expression(0); + exprs[0] = ta_exprs[0] * ta_shape.get_filled_expression(0); ctx->SetOutput(0, xla::Reshape(ta, shape, exprs)); Tensor lengths(DT_INT64, {ta_dims[0]}); @@ -535,7 +539,7 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape ta_shape; ta_shape.AddDim(resource->max_array_size()); - ta_shape.AddExpression(xla::DynExpr::_(resource->max_array_size())); + ta_shape.AddExpression(xla::DExpr::Const(resource->max_array_size())); ta_shape.AppendShape(elem_shape); OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 8531d2b0cccd3e..92128c0dc19873 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -504,8 +504,8 @@ class TensorListConcatOp : public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); - std::vector element_exprs = - xla::SpanToVector(element_shape.expressions()); + std::vector element_exprs(element_shape.expressions().begin(), + element_shape.expressions().end()); OP_REQUIRES( ctx, element_dims.size() > 1, errors::Unimplemented("TensorList of scalars is not supported")); @@ -513,8 +513,8 @@ class TensorListConcatOp : public XlaOpKernel { int64_t tensor_lengths = element_dims[1]; std::vector new_dims = {num_elements * tensor_lengths}; - std::vector new_exprs = { - xla::DynExpr::_(num_elements * tensor_lengths)}; + std::vector new_exprs = { + xla::DExpr::Const(num_elements * tensor_lengths)}; for (int i = 2; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); @@ -556,8 +556,8 @@ class TensorListSplitOp : 
public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); - std::vector element_exprs = - xla::SpanToVector(element_shape.expressions()); + std::vector element_exprs(element_shape.expressions().begin(), + element_shape.expressions().end()); OP_REQUIRES( ctx, !element_dims.empty(), errors::Unimplemented("Element dimensions have to be non-empty")); @@ -577,8 +577,8 @@ class TensorListSplitOp : public XlaOpKernel { ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); std::vector new_dims = {element_dims[0] / length, length}; - std::vector new_exprs = {*element_exprs[0] / length, - xla::DynExpr::_(length)}; + std::vector new_exprs = {element_exprs[0] / xla::DExpr::Const(length), + xla::DExpr::Const(length)}; for (int i = 1; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 1771181e440f31..9f235e6994e7d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -389,10 +389,12 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_part_dims = xla::SpanToVector(element_part_shape.dimensions()); element_part_dims.insert(element_part_dims.begin(), 1); - std::vector element_part_exprs = - xla::SpanToVector(element_part_shape.expressions()); - element_part_exprs.insert(element_part_exprs.begin(), - xla::DynExpr::one); + std::vector element_part_exprs; + element_part_exprs.reserve(element_part_shape.expressions().size() + 1); + element_part_exprs.push_back(xla::DExpr::Const(1)); + for (auto expr : element_part_shape.expressions()) { + element_part_exprs.push_back(expr); + } element_part = xla::Reshape(element_part, element_part_dims, element_part_exprs); @@ -411,9 
+413,12 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - std::vector element_exprs = - xla::SpanToVector(element_shape.expressions()); - element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); + std::vector element_exprs; + element_exprs.reserve(element_shape.expressions().size() + 1); + element_exprs.push_back(xla::DExpr::Const(1)); + for (auto expr : element_shape.expressions()) { + element_exprs.push_back(expr); + } xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, @@ -463,16 +468,19 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, xla::SpanToVector(list_part_shape.dimensions()); slice_shape[0] = 1LL; - std::vector slice_exprs = - xla::SpanToVector(list_part_shape.expressions()); - slice_exprs[0] = xla::DynExpr::_(1LL); + std::vector slice_exprs; + slice_exprs.reserve(list_part_shape.expressions().size()); + for (auto expr : list_part_shape.expressions()) { + slice_exprs.push_back(expr); + } + slice_exprs[0] = xla::DExpr::Const(1LL); xla::XlaOp list_part = xla::GetTupleElement(list, i); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); slice_shape.erase(slice_shape.begin()); - element_result_parts.push_back( - xla::Reshape(read, slice_shape, slice_exprs)); + slice_exprs.erase(slice_exprs.begin()); + element_result_parts.push_back(xla::Reshape(read, slice_shape, slice_exprs)); list_result_parts.push_back(list_part); } list_result_parts.push_back(push_index); @@ -506,10 +514,12 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - std::vector element_exprs = - xla::SpanToVector(element_shape.expressions()); - 
element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); - + std::vector element_exprs; + element_exprs.reserve(element_shape.expressions().size() + 1); + element_exprs.push_back(xla::DExpr::Const(1)); + for (auto expr : element_shape.expressions()) { + element_exprs.push_back(expr); + } xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, @@ -574,9 +584,12 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::SpanToVector(buffer_shape.dimensions()); slice_shape[0] = 1LL; - std::vector slice_exprs = - xla::SpanToVector(buffer_shape.expressions()); - slice_exprs[0] = xla::DynExpr::_(1LL); + std::vector slice_exprs; + slice_exprs.reserve(buffer_shape.expressions().size()); + for (auto expr : buffer_shape.expressions()) { + slice_exprs.push_back(expr); + } + slice_exprs[0] = xla::DExpr::Const(1LL); xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index b37aa02837eb68..da6ac93eedccd2 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -67,7 +67,7 @@ class TileOp : public XlaOpKernel { xla::ValueInferenceMode::kUpperBound)); std::vector output_dims(input_shape.dims()); - std::vector output_exprs(input_shape.dims()); + std::vector output_exprs(input_shape.dims()); auto expr_sizes = input_shape.get_filled_expressions(); @@ -77,7 +77,7 @@ class TileOp : public XlaOpKernel { "] >= 0, but got ", output_dims[i])); output_dims[i] = input_shape.dim_size(i) * multiples_bounds[i]; output_exprs[i] = - (*expr_sizes[i] * *xla::DynExpr::_(multiples_bounds[i]))->s(); + (expr_sizes[i] * xla::DExpr::Const(multiples_bounds[i])).simplify(); } std::vector multiples_are_dynamic; diff --git 
a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index fbe181d13ef547..f19278265b8ccd 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -84,7 +84,7 @@ class UniqueOpBase : public XlaOpKernel { // This is implemented as an hlo while loop. xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data, xla::XlaOp mask, int64_t size, - xla::DynExpr* expr) { + const xla::DExpr& expr) { xla::XlaComputation cond, body; xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size}); r1_shape.set_expression(0, expr); @@ -158,12 +158,12 @@ class UniqueOpBase : public XlaOpKernel { int64_t leading_size = aux_shape.dimensions(0); auto leading_expr = aux_shape.expressions(0); int64_t product = 1; - auto product_expr = xla::DynExpr::one; + auto product_expr = xla::DExpr::Const(1); for (int64_t i = 1; i < aux_shape.dimensions().size(); ++i) { product *= aux_shape.dimensions(i); - product_expr = *(product_expr) * *(aux_shape.expressions(i)); + product_expr = product_expr * aux_shape.expressions(i); } - product_expr = product_expr->s(); + product_expr = product_expr.simplify(); aux = xla::Reshape(aux, {leading_size, product}, {leading_expr, product_expr}); if (leading_size == 0) { @@ -215,10 +215,12 @@ class UniqueOpBase : public XlaOpKernel { auto permuted = xla::Gather(aux, perm, gather_dim_numbers, {1, product}); // Tail is everything except for first element. auto tail = xla::SliceInDim(permuted, 1, leading_size, - xla::DynExpr::one, leading_expr, 1, 0); + xla::DExpr::Const(1), leading_expr, 1, 0); // Init is everything except for last element. 
auto init = xla::SliceInDim(permuted, 0, leading_size - 1, - xla::DynExpr::zero, *leading_expr - 1, 1, 0); + xla::DExpr::Const(0), + (leading_expr - xla::DExpr::Const(1)).simplify(), + 1, 0); auto ne = xla::Compare(tail, init, xla::ComparisonDirection::kNe); auto reduce = xla::Reduce(ne, xla::ConstantR0(ctx->builder(), false), diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index 29bcf4fa0769ad..510086bc9caad0 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -165,11 +165,11 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions()); int64_t flattened_size = xla::Product(iota_shape.dimensions()); - xla::DynExpr* flattened_expr = xla::DynExpr::one; + xla::DExpr flattened_expr = xla::DExpr::Const(1); for (auto e : iota_shape.expressions()){ - flattened_expr = *flattened_expr * *e; + flattened_expr = flattened_expr * e; } - flattened_expr = flattened_expr->s(); + flattened_expr = flattened_expr.simplify(); XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); @@ -193,7 +193,7 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { for (int64_t i = 0; i < iota_shape.dimensions_size(); ++i) { XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1); to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}, - {flattened_expr, xla::DynExpr::one})); + {flattened_expr, xla::DExpr::Const(1)})); } XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1); @@ -221,11 +221,11 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { TF_ASSIGN_OR_RETURN(xla::Shape input_shape, b->GetShape(condition)); int64_t flattened_size = xla::Product(input_shape.dimensions()); - xla::DynExpr* flattened_expr = xla::DynExpr::one; + xla::DExpr flattened_expr = 
xla::DExpr::Const(1); for (auto e : input_shape.expressions()) { - flattened_expr = *flattened_expr * *e; + flattened_expr = flattened_expr * e; } - flattened_expr = flattened_expr->s(); + flattened_expr = flattened_expr.simplify(); XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); @@ -267,7 +267,7 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { /*on_true=*/prefix_sum - xla::One(b, S32), /*on_false=*/oob_idx); out_idxs = xla::Reshape(out_idxs, {flattened_size, 1}, - {flattened_expr, xla::DynExpr::one}); + {flattened_expr, xla::DExpr::Const(1)}); // tf.where returns an array of multidimensional indices where the condition // is true. For example: @@ -295,7 +295,7 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { iotas_to_concat.push_back( xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}, - {flattened_expr, xla::DynExpr::one})); + {flattened_expr, xla::DExpr::Const(1)})); } XlaOp iotas = xla::ConcatInDim(b, iotas_to_concat, /*dimension=*/1); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index c866c4429d4818..07d09559b15d35 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -34,7 +34,7 @@ namespace tensorflow { absl::StatusOr BroadcastTo( xla::XlaOp input, absl::Span output_dims, - absl::Span output_exprs) { + absl::Span output_exprs) { return xla::BroadcastTo(input, output_dims, output_exprs); } diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h index ee56975b664f7b..b5903775679b30 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -31,7 +31,7 @@ namespace tensorflow { // TODO(cheshire): Call the underlying function directly. 
absl::StatusOr BroadcastTo( xla::XlaOp input, absl::Span output_dims, - absl::Span output_exprs = {}); + absl::Span output_exprs = {}); // Forwards to xla::BroadcastOpsToSame. absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 38116fddb200af..5821a2752cd393 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -53,9 +53,10 @@ absl::StatusOr Contract(xla::XlaOp input, int64_t dim) { input_shape.dimensions().end() - 1); contracted_shape[dim] *= 4; - std::vector contracted_exprs( + std::vector contracted_exprs( input_shape.expressions().begin(), input_shape.expressions().end() - 1); - contracted_exprs[dim] = (*(contracted_exprs[dim]) * *xla::DynExpr::_(4))->s(); + contracted_exprs[dim] = + (contracted_exprs[dim] * xla::DExpr::Const(4)).simplify(); return xla::Reshape(xla::Transpose(input, permutation), contracted_shape, contracted_exprs); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index c5ce4b52ad8cd6..f9ee621bbe5a6f 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/jit/flags.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" @@ -1144,15 +1143,13 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, } std::vector dims; std::vector dynamic_dims; - std::vector expressions; - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + std::vector expressions; for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) { bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i)); int dynamic_multiplier = c->DynamicRatio(c->Dim(shape_handle, i)); dynamic_dims.push_back(is_dynamic); - if (flags->tf_xla_enable_dynamic_sizes) { - expressions.push_back(dynamic_multiplier * *xla::DynExpr::V(1)); - } + expressions.push_back(xla::DExpr::Const(dynamic_multiplier) * + xla::DExpr::Var(1)); dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize : c->Value(c->Dim(shape_handle, i))); } @@ -1161,9 +1158,7 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, xla::PrimitiveType::S64, dims, absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end())); - if (flags->tf_xla_enable_dynamic_sizes) { - sh.set_expressions(expressions); - } + sh.set_expressions(expressions); return sh; } diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 54384ae5531636..7e04acf7ecb503 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "xla/layout_util.h" #include "xla/shape_util.h" @@ -101,12 +100,9 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - if (flags->tf_xla_enable_dynamic_sizes) { - std::vector bexprs(shape.expressions().begin(), - shape.expressions().end()); - tensor_shape->set_expressions(bexprs); - } + std::vector dexprs(shape.expressions().begin(), + shape.expressions().end()); + tensor_shape->set_expressions(std::move(dexprs)); return absl::OkStatus(); } @@ -174,12 +170,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, } int rank = tensor_shape.dims(); std::vector dimensions(rank); + std::vector expressions(rank); std::vector layout(rank); - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - std::vector expressions; - if (flags->tf_xla_enable_dynamic_sizes) { - expressions.resize(rank); - } for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); if (dimensions[d] < 0) { @@ -187,17 +179,13 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, "shape; returning unknown sentinel value"; return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } - if (flags->tf_xla_enable_dynamic_sizes) { - expressions[d] = tensor_shape.get_filled_expression(d); - } + expressions[d] = tensor_shape.get_filled_expression(d); } // XLA uses minor-to-major; Tensorflow uses major-to-minor. 
std::iota(layout.rbegin(), layout.rend(), 0); xla::Shape result = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - if (flags->tf_xla_enable_dynamic_sizes) { - result.set_expressions(expressions); - } + result.set_expressions(expressions); return result; } @@ -225,14 +213,11 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - std::vector expressions; - if (flags->tf_xla_enable_dynamic_sizes) { - expressions = tensor_shape.get_filled_expressions(); - } + std::vector expressions(rank); for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); + expressions[d] = tensor_shape.get_filled_expression(d); } // XLA uses minor-to-major; Tensorflow uses major-to-minor. @@ -240,9 +225,7 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, auto shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - if (flags->tf_xla_enable_dynamic_sizes) { - shape.set_expressions(expressions); - } + shape.set_expressions(expressions); return shape; } diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 2694cd275be118..bbdc7df77e1b75 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -80,7 +80,7 @@ struct XlaArgument { // When non-negative, marks the single constant element that later passes may // reinterpret as coming from a dynamic expression instead of the literal. int64_t dynamic_constant_index = -1; - xla::DynExpr* dynamic_constant_expr = nullptr; + std::optional dynamic_constant_expr; // The upper bounds of the value. std::optional value_bound; @@ -121,7 +121,7 @@ struct XlaArgument { // Returns the dimension sizes for either TensorShape or xla::Shape. 
std::vector DimensionSizes() const; - std::vector DimensionExpressions() const; + std::vector DimensionExpressions() const; absl::InlinedVector DimensionSizesAsInlinedVector() const; // Returns the human-readable string for either TensorShape or xla::Shape. diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 87b64ae7a61bfb..3e33fdecccc629 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -511,11 +511,16 @@ std::vector XlaCompiler::Argument::DimensionSizes() const { } } -std::vector XlaCompiler::Argument::DimensionExpressions() const { +std::vector XlaCompiler::Argument::DimensionExpressions() const { if (absl::holds_alternative(shape)) { return std::get(shape).get_filled_expressions(); } else { - return xla::SpanToVector(std::get(shape).expressions()); + std::vector expressions; + expressions.reserve(std::get(shape).expressions().size()); + for (const auto& expr : std::get(shape).expressions()) { + expressions.push_back(expr); + } + return expressions; } } @@ -1121,7 +1126,9 @@ absl::Status XlaCompiler::BuildArguments( arg_expression = XlaExpression::Constant(arg.constant_value); if (arg.dynamic_constant_index >= 0) { arg_expression.set_dynamic_constant_index(arg.dynamic_constant_index); - arg_expression.set_dynamic_constant_expr(arg.dynamic_constant_expr); + if (arg.dynamic_constant_expr.has_value()) { + arg_expression.set_dynamic_constant_expr(*arg.dynamic_constant_expr); + } } break; case XlaCompiler::Argument::kInvalid: diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index e62a115b69eb82..a6485b593e2591 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -99,7 +99,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { TF_RETURN_IF_ERROR( HostTensorToBorrowingLiteral(*constant_value_, &literal)); if 
(!dynamic_constant_index_.has_value() || - dynamic_constant_expr_ == nullptr) { + !dynamic_constant_expr_.has_value()) { return xla::ConstantLiteral(builder, literal); } diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 173e3725fb5de1..a9c20705e6bffc 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -118,9 +118,11 @@ class XlaExpression { void set_dynamic_constant_index(int64_t index) { dynamic_constant_index_ = index; } - xla::DynExpr* dynamic_constant_expr() const { return dynamic_constant_expr_; } - void set_dynamic_constant_expr(xla::DynExpr* expr) { - dynamic_constant_expr_ = expr; + const std::optional& dynamic_constant_expr() const { + return dynamic_constant_expr_; + } + void set_dynamic_constant_expr(xla::DExpr expr) { + dynamic_constant_expr_ = std::move(expr); } XlaResource* resource() const { return resource_; } @@ -178,7 +180,7 @@ class XlaExpression { // For constant expressions, marks a single element that later passes may // reinterpret as coming from a dynamic expression instead of the literal. std::optional dynamic_constant_index_; - xla::DynExpr* dynamic_constant_expr_ = nullptr; + std::optional dynamic_constant_expr_; // The resource, if kind_ == kResource. Not owned. XlaResource* resource_ = nullptr; diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 4ef579655112ee..54013ddb603381 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -16,14 +16,15 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/overflow.h" +#include "xla/printer.h" namespace tensorflow { @@ -31,105 +32,92 @@ namespace { const bool kTensorShapeExpressionsEnabled = TensorShapeExpressionsEnabled(); -} // namespace - -xla::DynExpr* ExprFromProto(const ExpressionProto& proto) { +xla::DExpr DExprFromProto(const ExpressionProto& proto) { switch (proto.node_type_case()) { case ExpressionProto::kConstantValue: - return xla::DynExpr::_(proto.constant_value()); - + return xla::DExpr::Const(proto.constant_value()); case ExpressionProto::kVariableId: - return xla::DynExpr::V(proto.variable_id()); - + return xla::DExpr::Var(proto.variable_id()); case ExpressionProto::kAddNode: { const auto& add = proto.add_node(); - return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); + return DExprFromProto(add.lhs()) + DExprFromProto(add.rhs()); } - case ExpressionProto::kSubNode: { const auto& sub = proto.sub_node(); - return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); + return DExprFromProto(sub.lhs()) - DExprFromProto(sub.rhs()); } - case ExpressionProto::kMulNode: { const auto& mul = proto.mul_node(); - return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); + return DExprFromProto(mul.lhs()) * DExprFromProto(mul.rhs()); } - case ExpressionProto::kDivNode: { const auto& div = proto.div_node(); - return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); + return DExprFromProto(div.lhs()) / DExprFromProto(div.rhs()); } - case 
ExpressionProto::NODE_TYPE_NOT_SET: default: - return nullptr; + return xla::DExpr::Unknown(); } } -void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { - auto e = expr->s(); - if (xla::Constant* c = dynamic_cast(e)) { - proto->set_constant_value(c->get_val()); - } else if (xla::Variable* v = dynamic_cast(e)) { - proto->set_variable_id(v->get_id()); - } else if (xla::Add* a = dynamic_cast(e)) { - auto* add_msg = proto->mutable_add_node(); - ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); - ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); - } else if (xla::Mul* m = dynamic_cast(e)) { - auto* mul_msg = proto->mutable_mul_node(); - ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); - ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); - } else if (xla::Sub* s = dynamic_cast(e)) { - auto* sub_msg = proto->mutable_sub_node(); - ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); - ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); - } else if (xla::Div* d = dynamic_cast(e)) { - auto* div_msg = proto->mutable_div_node(); - ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); - ExprToProto(d->get_rhs(), div_msg->mutable_rhs()); +void ExprToProto(const xla::DExpr& expr, ExpressionProto* proto) { + if (!expr) return; + switch (expr.kind()) { + case xla::DExpr::Kind::kUnknown: + return; + case xla::DExpr::Kind::kConstant: + proto->set_constant_value(expr->get_val()); + return; + case xla::DExpr::Kind::kVariable: + proto->set_variable_id( + static_cast(*expr.get()).get_id()); + return; + case xla::DExpr::Kind::kAdd: { + auto* add = proto->mutable_add_node(); + const auto& node = static_cast(*expr.get()); + ExprToProto(xla::DExpr(node.get_lhs()->clone()), + add->mutable_lhs()); + ExprToProto(xla::DExpr(node.get_rhs()->clone()), + add->mutable_rhs()); + return; + } + case xla::DExpr::Kind::kSub: { + auto* sub = proto->mutable_sub_node(); + const auto& node = static_cast(*expr.get()); + ExprToProto(xla::DExpr(node.get_lhs()->clone()), + sub->mutable_lhs()); + 
ExprToProto(xla::DExpr(node.get_rhs()->clone()), + sub->mutable_rhs()); + return; + } + case xla::DExpr::Kind::kMul: { + auto* mul = proto->mutable_mul_node(); + const auto& node = static_cast(*expr.get()); + ExprToProto(xla::DExpr(node.get_lhs()->clone()), + mul->mutable_lhs()); + ExprToProto(xla::DExpr(node.get_rhs()->clone()), + mul->mutable_rhs()); + return; + } + case xla::DExpr::Kind::kDiv: { + auto* div = proto->mutable_div_node(); + const auto& node = static_cast(*expr.get()); + ExprToProto(xla::DExpr(node.get_lhs()->clone()), + div->mutable_lhs()); + ExprToProto(xla::DExpr(node.get_rhs()->clone()), + div->mutable_rhs()); + return; + } } } -// Independent helper function to handle the recursion -void BuildExprString(xla::DynExpr* e, std::ostringstream& oss) { - if (xla::Constant* c = dynamic_cast(e)) { - oss << c->get_val(); - } else if (xla::Variable* v = dynamic_cast(e)) { - char letter = 'A' + (v->get_id() - 1); - oss << letter; - } else if (xla::Add* a = dynamic_cast(e)) { - oss << "("; - BuildExprString(a->get_lhs(), oss); - oss << " + "; - BuildExprString(a->get_rhs(), oss); - oss << ")"; - } else if (xla::Mul* m = dynamic_cast(e)) { - oss << "("; - BuildExprString(m->get_lhs(), oss); - oss << " * "; - BuildExprString(m->get_rhs(), oss); - oss << ")"; - } else if (xla::Sub* s = dynamic_cast(e)) { - oss << "("; - BuildExprString(s->get_lhs(), oss); - oss << " - "; - BuildExprString(s->get_rhs(), oss); - oss << ")"; - } else if (xla::Div* d = dynamic_cast(e)) { - oss << "("; - BuildExprString(d->get_lhs(), oss); - oss << " / "; - BuildExprString(d->get_rhs(), oss); - oss << ")"; - } -} +} // namespace -std::string ExprToString(xla::DynExpr* e) { - std::ostringstream oss; - BuildExprString(e, oss); - return oss.str(); +std::string ExprToString(const xla::DExpr& e) { + if (!e && !e.is_unknown()) return ""; + xla::StringPrinter printer; + e->print(&printer); + return std::move(printer).ToString(); } // TensorShape and PartialTensorShape should have no 
fields beyond @@ -260,7 +248,7 @@ TensorShapeBase::TensorShapeBase(const TensorShapeProto& proto) { } if (kTensorShapeExpressionsEnabled) { for (const auto& e : proto.expressions()) { - AddExpression(ExprFromProto(e)); + AddExpression(DExprFromProto(e)); } } } @@ -304,7 +292,7 @@ absl::Status TensorShapeBase::BuildTensorShapeBase( } if (kTensorShapeExpressionsEnabled) { for (const auto& e : proto.expressions()) { - out->AddExpression(ExprFromProto(e)); + out->AddExpression(DExprFromProto(e)); } } } @@ -491,41 +479,38 @@ void TensorShapeRep::Clear() { set_data_type(DT_INVALID); } -void TensorShapeRep::set_expression(int d, xla::DynExpr* expr) { +void TensorShapeRep::set_expression(int d, xla::DExpr expr) { if (!kTensorShapeExpressionsEnabled) { expressions_.clear(); return; } if (expressions_.size() <= static_cast(d)) { - expressions_.resize(d + 1, nullptr); + expressions_.resize(d + 1); } - expressions_[d] = expr; + expressions_[d] = std::move(expr); } -void TensorShapeRep::AddExpression(xla::DynExpr* expr) { +void TensorShapeRep::AddExpression(xla::DExpr expr) { if (!kTensorShapeExpressionsEnabled) { return; } CHECK_LT(expressions_.size(), ndims_byte()); - expressions_.push_back(expr); + expressions_.push_back(std::move(expr)); } -void TensorShapeRep::set_expressions(std::vector exprs) { +void TensorShapeRep::set_expressions(std::vector exprs) { if (!kTensorShapeExpressionsEnabled) { expressions_.clear(); return; } - while (!exprs.empty() && exprs.back() == nullptr) { - exprs.pop_back(); - } - expressions_ = exprs; + expressions_ = std::move(exprs); } void TensorShapeRep::ClearAllButDataType() { - expressions_.clear(); if (tag() == REP_OUT_OF_LINE) { delete as64()->dims_; } + expressions_.clear(); set_tag(REP16); set_ndims_byte(0); // Leaves data_type alone @@ -735,9 +720,7 @@ template void TensorShapeBase::set_dim(int d, int64_t size) { CHECK_GE(d, 0); CHECK_LT(d, dims()); - if (d < expressions_.size() && expressions_[d] != nullptr) { - set_expression(d, 
xla::DynExpr::_(size)); - } + if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size)); if (!kIsPartial) { CHECK_GE(size, 0); } @@ -799,9 +782,7 @@ absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { } } - if (d < expressions_.size() && expressions_[d] != nullptr) { - set_expression(d, xla::DynExpr::_(size)); - } + if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size)); return RecomputeNumElements(); } @@ -817,7 +798,8 @@ void TensorShapeBase::RemoveDimRange(int begin, int end) { if (begin >= end) return; absl::InlinedVector vals; AppendTo(*this, &vals); - std::vector new_exprs = get_expressions(); + std::vector new_exprs(get_expressions().begin(), + get_expressions().end()); if (begin < static_cast(new_exprs.size())) { int64_t expr_end = end; if (expr_end > static_cast(new_exprs.size())) { @@ -877,7 +859,8 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, absl::InlinedVector vals; AppendTo(*this, &vals); - std::vector new_exprs = get_expressions(); + std::vector new_exprs(get_expressions().begin(), + get_expressions().end()); if (begin < static_cast(new_exprs.size())) { int64_t expr_end = end; @@ -926,11 +909,9 @@ void TensorShapeBase::AsProto(TensorShapeProto* proto) const { proto->add_dim()->set_size(dim_size(i)); } if (kTensorShapeExpressionsEnabled) { - for (int i = 0; i < expressions_.size(); ++i) { + for (int i = 0; i < get_expressions().size(); i++) { ExpressionProto* eproto = proto->add_expressions(); - if (expressions_[i] != nullptr) { - ExprToProto(expressions_[i], eproto); - } + ExprToProto(get_expression(i), eproto); } } } @@ -966,10 +947,9 @@ string TensorShapeRep::DebugString() const { } else { strings::StrAppend(&s, dim); } - if (kTensorShapeExpressionsEnabled && i < expressions_.size() && - expressions_[i] != nullptr) { + if (shape.get_expression(i)) { strings::StrAppend(&s, "<"); - strings::StrAppend(&s, ExprToString(expressions_[i])); + strings::StrAppend(&s, 
ExprToString(shape.get_expression(i))); strings::StrAppend(&s, ">"); } } @@ -1000,7 +980,7 @@ string TensorShapeRep::DebugString(const TensorShapeProto& proto) { first = true; for (const auto& e : proto.expressions()) { if (!first) strings::StrAppend(&s, ","); - auto exp = ExprFromProto(e); + auto exp = DExprFromProto(e); strings::StrAppend(&s, ExprToString(exp)); first = false; } @@ -1169,7 +1149,8 @@ absl::Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, return s; } } - result->set_expressions(shape.get_expressions()); + result->set_expressions(std::vector(shape.get_expressions().begin(), + shape.get_expressions().end())); return absl::OkStatus(); } diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 834c630a6dfa82..5507cfac636fa0 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -75,25 +75,25 @@ class TensorShapeRep { std::string DebugString() const; static std::string DebugString(const TensorShapeProto& proto); - void set_expression(int d, xla::DynExpr* expr); + void set_expression(int d, xla::DExpr expr); - void AddExpression(xla::DynExpr* expr); + void AddExpression(xla::DExpr expr); // Set the array of dynamic multipliers. - void set_expressions(std::vector exprs); + void set_expressions(std::vector exprs); // Get the array of dynamic multipliers. - std::vector get_expressions() const { + absl::Span get_expressions() const { return expressions_; } // Get the array of dynamic multipliers, filling missing entries with // constant expressions derived from the concrete dimensions. 
- std::vector get_filled_expressions() const { + std::vector get_filled_expressions() const { if (ndims_byte() == kUnknownRank) { return {}; } - std::vector exprs(ndims_byte()); + std::vector exprs(ndims_byte()); for (int i = 0; i < ndims_byte(); ++i) { exprs[i] = get_filled_expression(i); } @@ -102,33 +102,32 @@ class TensorShapeRep { // Return the multiplier for a specific dynamic dimension. // -1 if the dimension is not dynamic. - xla::DynExpr* get_expression(int64_t dimension) const { - if (dimension < 0) return missing_expression(); + const xla::DExpr& get_expression(int64_t dimension) const { + static const xla::DExpr kMissingExpression = xla::DExpr::Unknown(); + if (dimension < 0) return kMissingExpression; const size_t dim = static_cast(dimension); if (dim >= expressions_.size()) { - return missing_expression(); + return kMissingExpression; } - return expressions_[dim] != nullptr ? expressions_[dim] - : missing_expression(); + return expressions_[dim]; } - // Return the multiplier for a specific dynamic dimension, materializing a - // constant expression for concrete dimensions without a stored expression. 
- xla::DynExpr* get_filled_expression(int64_t dimension) const { - if (dimension < 0) return missing_expression(); + xla::DExpr get_filled_expression(int64_t dimension) const { + if (dimension < 0) return xla::DExpr::Unknown(); const size_t dim = static_cast(dimension); - if (dim < expressions_.size() && expressions_[dim] != nullptr) { + if (dim < expressions_.size() && expressions_[dim]) { return expressions_[dim]; } if (ndims_byte() == kUnknownRank || dim >= ndims_byte()) { - return missing_expression(); + return xla::DExpr::Unknown(); } + return constant_expression_for_dim(dim); } protected: - std::vector expressions_; + std::vector expressions_; // Constructable only via TensorShapeBase TensorShapeRep() = default; @@ -199,12 +198,7 @@ class TensorShapeRep { void set_num_elements(int64_t n) { num_elements_ = n; } private: - static xla::DynExpr* missing_expression() { - static xla::DynExpr* const missing = xla::DynExpr::_(-999); - return missing; - } - - xla::DynExpr* constant_expression_for_dim(size_t dim) const { + xla::DExpr constant_expression_for_dim(size_t dim) const { int64_t dim_value = -1; if (tag() == REP16) { uint16 raw_dim = as16()->dims_[dim]; @@ -215,7 +209,7 @@ class TensorShapeRep { } else { dim_value = (*as64()->dims_)[dim]; } - return xla::DynExpr::_(dim_value); + return xla::DExpr::Const(dim_value); } void DestructorOutOfLine(); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 4d880061828909..181567cc0fc888 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -2040,9 +2040,9 @@ class SymbolicShapeRefiner { } bool changed = false; + const bool force_fresh_unknown_dim = node->op() == "Reshape"; std::vector dims; dims.reserve(ic->Rank(s)); - const bool is_reshape = node->op() == "Reshape"; for (int d = 0; d < ic->Rank(s); ++d) { DimensionHandle dim = ic->Dim(s, d); const int64_t v = ic->Value(dim); @@ 
-2053,10 +2053,7 @@ class SymbolicShapeRefiner { } // If already tagged with expr, keep it. auto* dim_expr = ic->GetDimExpr(dim); - const bool refresh_reshape_expr = - is_reshape && dim_expr != nullptr && - dim_expr->kind() != DimExpr::Kind::kVariable; - if (dim_expr != nullptr && !refresh_reshape_expr) { + if (dim_expr != nullptr && !force_fresh_unknown_dim) { dims.push_back(dim); continue; } diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 84cab8bf24230c..ed669f60550501 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -402,9 +402,9 @@ std::vector PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero( TensorShape& shape = shapes[i]; for (int d = 0; d < partial.dims(); ++d) { shape.AddDim(partial.dim_size(d) < 0 ? 0 : partial.dim_size(d)); - xla::DynExpr* expr = partial.get_filled_expression(d); - if (expr != nullptr && expr->is_constant() && expr->get_val() < 0) { - expr = xla::DynExpr::zero; + xla::DExpr expr = partial.get_filled_expression(d); + if (expr && expr->is_constant() && expr->get_val() < 0) { + expr = xla::DExpr::Const(0); } shape.AddExpression(expr); } diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 8dc2b49e87e09a..7a2cf9d55922af 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -309,8 +309,8 @@ class StridedSliceAssignOp : public OpKernel { bool is_simple_slice = true; absl::InlinedVector begin; absl::InlinedVector end; - absl::InlinedVector begin_expr; - absl::InlinedVector end_expr; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; Tensor* old_lhs = nullptr; diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 9b495591f38d61..a2bb25f9c2e923 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ 
b/tensorflow/core/util/strided_slice_op.cc @@ -57,8 +57,8 @@ struct StridedSliceDenseSpec { bool end_valid; absl::InlinedVector& begin; absl::InlinedVector& end; - absl::InlinedVector& begin_expr; - absl::InlinedVector& end_expr; + absl::InlinedVector& begin_expr; + absl::InlinedVector& end_expr; absl::InlinedVector& strides; // This vector helps construct the final shape of the slice. // The final tensor is reduced in rank whenever a single index e.g. foo[3] @@ -134,7 +134,7 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, // new_axis' aren't real axis so you have to skip dense->begin[full_index] = dense->end[full_index] = 0; dense->begin_expr[full_index] = dense->end_expr[full_index] = - xla::DynExpr::zero; + xla::DExpr::Const(0); dense->strides[full_index] = 1; dense->begin_mask |= (1 << full_index); dense->end_mask |= (1 << full_index); @@ -159,12 +159,12 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, if (begin_flat != nullptr) { dense->begin[full_index] = internal::SubtleMustCopy(begin_flat[i]); dense->begin_expr[full_index] = - xla::DynExpr::_(dense->begin[full_index]); + xla::DExpr::Const(dense->begin[full_index]); } if (end_flat != nullptr) { dense->end[full_index] = internal::SubtleMustCopy(end_flat[i]); dense->end_expr[full_index] = - xla::DynExpr::_(dense->end[full_index]); + xla::DExpr::Const(dense->end[full_index]); } dense->strides[full_index] = internal::SubtleMustCopy(strides_flat[i]); @@ -205,22 +205,22 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, - absl::InlinedVector* begin_expr, - absl::InlinedVector* end_expr, + absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { - absl::InlinedVector b; - absl::InlinedVector e; + absl::InlinedVector b; + absl::InlinedVector e; // HACK if (begin_expr == nullptr) { for (int i : *begin) { - b.push_back(xla::DynExpr::_(i)); + 
b.push_back(xla::DExpr::Const(i)); } begin_expr = &b; } if (end_expr == nullptr) { for (int i : *end) { - e.push_back(xla::DynExpr::_(i)); + e.push_back(xla::DExpr::Const(i)); } end_expr = &e; } @@ -337,8 +337,8 @@ absl::Status ValidateStridedSliceOp( int64_t& stride_i = (*strides)[i]; int64_t dim_i = input_shape.dim_size(i); - xla::DynExpr* dim_i_expr = - i < dim_exprs.size() ? dim_exprs[i] : xla::DynExpr::_(dim_i); + xla::DExpr dim_i_expr = + i < dim_exprs.size() ? dim_exprs[i] : xla::DExpr::Const(dim_i); if (stride_i == 0) { return errors::InvalidArgument("strides[", i, "] must be non-zero"); @@ -346,8 +346,8 @@ absl::Status ValidateStridedSliceOp( bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); if (dim_i == -1) { processing_shape->AddDim(shrink_i ? 1 : -1); - processing_shape->AddExpression(shrink_i ? xla::DynExpr::_(1) - : xla::DynExpr::_(-1)); + processing_shape->AddExpression(shrink_i ? xla::DExpr::Const(1) + : xla::DExpr::Const(-1)); continue; } @@ -356,9 +356,10 @@ absl::Status ValidateStridedSliceOp( const std::array valid_range = { {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}}; - const std::array valid_range_expr = { - {stride_i > 0 ? xla::DynExpr::zero : xla::DynExpr::_(-1), - stride_i > 0 ? dim_i_expr : (*dim_i_expr - *xla::DynExpr::one)->s()}}; + const std::array valid_range_expr = { + {stride_i > 0 ? xla::DExpr::Const(0) : xla::DExpr::Const(-1), + stride_i > 0 ? dim_i_expr + : (dim_i_expr - xla::DExpr::Const(1)).simplify()}}; auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) { if (masks[c]) { @@ -379,9 +380,9 @@ absl::Status ValidateStridedSliceOp( } else { int64_t x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive - xla::DynExpr* x_expr = xla::DynExpr::_(x); - xla::DynExpr* x_fwd_expr = - x < 0 ? (*dim_i_expr + *x_expr) + xla::DExpr x_expr = xla::DExpr::Const(x); + xla::DExpr x_fwd_expr = + x < 0 ? 
(dim_i_expr + x_expr) : x_expr; // make negative indices positive return x_fwd < valid_range[0] ? valid_range_expr[0] : x_fwd > valid_range[1] ? valid_range_expr[1] @@ -403,14 +404,14 @@ absl::Status ValidateStridedSliceOp( // and canonical puts these to n-1 and 0, which implies a degenerate // interval. Fortunately, it is now safe to re-create end as begin+1. int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; - xla::DynExpr* x_fwd_expr = begin_i < 0 - ? (*dim_i_expr + *(*begin_expr)[i])->s() - : (*begin_expr)[i]; + xla::DExpr x_fwd_expr = begin_i < 0 + ? (dim_i_expr + (*begin_expr)[i]).simplify() + : (*begin_expr)[i]; begin_i = x_fwd; end_i = begin_i + 1; (*begin_expr)[i] = x_fwd_expr; - (*end_expr)[i] = (*(*begin_expr)[i] + *xla::DynExpr::one)->s(); + (*end_expr)[i] = ((*begin_expr)[i] + xla::DExpr::Const(1)).simplify(); if (x_fwd < 0 || x_fwd >= dim_i) { return errors::InvalidArgument( @@ -422,10 +423,10 @@ absl::Status ValidateStridedSliceOp( begin_i = canonical(begin_raw, 0); end_i = canonical(end_raw, 1); if (begin_expr) { - (*begin_expr)[i] = canonical_expr(begin_raw, 0)->s(); + (*begin_expr)[i] = canonical_expr(begin_raw, 0).simplify(); } if (end_expr) { - (*end_expr)[i] = canonical_expr(end_raw, 1)->s(); + (*end_expr)[i] = canonical_expr(end_raw, 1).simplify(); } } // Update optimization values @@ -439,17 +440,17 @@ absl::Status ValidateStridedSliceOp( } // Compute the processing shape (the intermediate Eigen will produce) int64_t interval_length; - xla::DynExpr* interval_length_expr; + xla::DExpr interval_length_expr; bool known_interval = false; if (dense_spec.begin_valid && dense_spec.end_valid) { interval_length = end_i - begin_i; - interval_length_expr = (*(*end_expr)[i] - *(*begin_expr)[i])->s(); + interval_length_expr = ((*end_expr)[i] - (*begin_expr)[i]).simplify(); known_interval = true; } else if (shrink_i) { // The dimension is still known as 1 for the processing_shape, but will be // discarded for the final shape. 
interval_length = 1; - interval_length_expr = xla::DynExpr::one; + interval_length_expr = xla::DExpr::Const(1); known_interval = true; } else if (begin_and_end_masked) { // Even if we don't have values for begin or end, we do know that this @@ -458,7 +459,7 @@ absl::Status ValidateStridedSliceOp( if (dim_i >= 0) { if (stride_i < 0) { interval_length = -dim_i; - interval_length_expr = (-1 * (*dim_i_expr))->s(); + interval_length_expr = (xla::DExpr::Const(-1) * dim_i_expr).simplify(); } else { interval_length = dim_i; interval_length_expr = dim_i_expr; @@ -468,24 +469,25 @@ absl::Status ValidateStridedSliceOp( } if (known_interval) { int64_t size_i; - xla::DynExpr* size_i_expr; + xla::DExpr size_i_expr; // Hold zero if the interval is degenerate, otherwise account for // remainder if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { size_i = 0; - size_i_expr = xla::DynExpr::zero; + size_i_expr = xla::DExpr::Const(0); } else { size_i = interval_length / stride_i + (interval_length % stride_i != 0 ? 1 : 0); - size_i_expr = *(*interval_length_expr / stride_i) + - *(interval_length % stride_i != 0 ? xla::DynExpr::one - : xla::DynExpr::zero); + size_i_expr = + (interval_length_expr / xla::DExpr::Const(stride_i)) + + (interval_length % stride_i != 0 ? 
xla::DExpr::Const(1) + : xla::DExpr::Const(0)); } processing_shape->AddDim(size_i); - processing_shape->AddExpression(size_i_expr->s()); + processing_shape->AddExpression(size_i_expr.simplify()); } else { processing_shape->AddDim(-1); - processing_shape->AddExpression(xla::DynExpr::_(-1)); + processing_shape->AddExpression(xla::DExpr::Const(-1)); } } @@ -522,7 +524,7 @@ absl::Status ValidateStridedSliceOp( } } else if (gather_index == kNewAxis) { final_shape->AddDim(1); - final_shape->AddExpression(xla::DynExpr::one); + final_shape->AddExpression(xla::DExpr::Const(1)); if (shape_spec != nullptr) { shape_spec->output_to_sparse_mapping.push_back(-1); shape_spec->output_to_processing_mapping.push_back(-1); @@ -533,6 +535,7 @@ absl::Status ValidateStridedSliceOp( return absl::OkStatus(); } + absl::Status ValidateStridedSliceOp( const Tensor* begin_tensor, const Tensor* end_tensor, const Tensor& strides_tensor, const PartialTensorShape& input_shape, @@ -543,10 +546,9 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, - absl::InlinedVector* begin_expr, - absl::InlinedVector* end_expr, + absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { - // Validate with PartialTensorShape output PartialTensorShape partial_processing_shape, partial_final_shape; TF_RETURN_IF_ERROR(ValidateStridedSliceOp( begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, @@ -554,8 +556,6 @@ absl::Status ValidateStridedSliceOp( &partial_processing_shape, &partial_final_shape, is_identity, is_simple_slice, slice_dim0, begin, end, strides, begin_expr, end_expr, shape_spec)); - - // Verify that the output shapes are fully known if (!partial_processing_shape.AsTensorShape(processing_shape) || !partial_final_shape.AsTensorShape(final_shape)) { return errors::Internal("ValidateStridedSliceOp returned partial shapes ", diff --git 
a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h index 3e55d9bb384481..c7240d15adc6c7 100644 --- a/tensorflow/core/util/strided_slice_op.h +++ b/tensorflow/core/util/strided_slice_op.h @@ -75,8 +75,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, - absl::InlinedVector* begin_expr = nullptr, - absl::InlinedVector* end_expr = nullptr, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Same as above, but the outputs are TensorShape, not PartialTensorShape @@ -90,8 +90,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, - absl::InlinedVector* begin_expr = nullptr, - absl::InlinedVector* end_expr = nullptr, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Simple class for determining if it is possible to broadcast a tensor to a diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc index 292a36b623596c..5c7d49c73e1661 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -33,7 +33,7 @@ namespace xla { absl::StatusOr BroadcastTo( XlaOp input, absl::Span output_dims, - absl::Span output_exprs) { + absl::Span output_exprs) { XlaBuilder* builder = input.builder(); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); absl::Span input_dims = input_shape.dimensions(); @@ -48,12 +48,6 @@ absl::StatusOr BroadcastTo( ") must have rank less than or equal to the output shape [", absl::StrJoin(output_dims, ","), "]"); } - - if (!output_exprs.empty() && output_exprs.size() != output_dims.size()) { - return tsl::errors::InvalidArgument( - "output_exprs must be empty or have the same rank as 
output_dims: ", - output_exprs.size(), " vs ", output_dims.size()); - } std::vector broadcast_dims; std::vector broadcast_shape; @@ -85,23 +79,20 @@ absl::StatusOr BroadcastTo( } } TF_RET_CHECK(input_it == input_dims.rend()); - absl::Span input_exprs = input_shape.expressions(); - std::vector broadcast_exprs; - auto input_dim_et = input_dims.rbegin(); + + absl::Span input_exprs = input_shape.expressions(); + std::vector broadcast_exprs; auto input_et = input_exprs.rbegin(); - auto output_dim_et = output_dims.rbegin(); for (auto output_et = output_exprs.rbegin(); output_et != output_exprs.rend(); - ++output_et, ++output_dim_et) { + ++output_et) { if (input_et != input_exprs.rend()) { - if (*output_dim_et == *input_dim_et || *input_dim_et == 1 || - *(*output_et) == *(*input_et) || - (*input_et)->is_constant() && (*input_et)->get_val() == 1) { + if (**output_et == **input_et || + (input_et->get()->is_constant() && input_et->get()->get_val() == 1)) { broadcast_exprs.push_back(*output_et); - } else if (!(*(*output_et) == *(*input_et))) { + } else if (!(**output_et == **input_et)) { broadcast_exprs.push_back(*input_et); - broadcast_exprs.push_back((**output_et / **input_et)->s()); + broadcast_exprs.push_back((*output_et / *input_et).simplify()); } - ++input_dim_et; ++input_et; } else { broadcast_exprs.push_back(*output_et); diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.h b/third_party/xla/xla/hlo/builder/lib/broadcast.h index b4ccc51a40d4fd..dd5535f219b971 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.h +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.h @@ -29,7 +29,7 @@ namespace xla { // rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. 
absl::StatusOr BroadcastTo( XlaOp input, absl::Span output_dims, - absl::Span output_exprs = {}); + absl::Span output_exprs = {}); } // namespace xla diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc index 3bd051eb9f582f..077a81c1ce38b4 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -342,7 +342,7 @@ xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span config) { } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); std::vector broadcast_sizes; - std::vector broadcast_exprs; + std::vector broadcast_exprs; int64_t x_dim = 0; for (auto label = config.begin(); label != config.end(); ++label) { auto first_label = absl::c_find(config, *label); @@ -573,14 +573,14 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, int64_t dot_dim = 0; std::vector new_dims; - std::vector new_exprs; + std::vector new_exprs; new_dims.reserve(output_rank); new_exprs.reserve(output_rank); TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot)); for (auto d : output_config) { if (is_output_only(d)) { new_dims.push_back(1); - new_exprs.push_back(DynExpr::one); + new_exprs.push_back(DExpr::Const(1)); } else { new_dims.push_back(dot_shape.dimensions(dot_dim)); new_exprs.push_back(dot_shape.expressions(dot_dim)); diff --git a/third_party/xla/xla/hlo/builder/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc index a18fa3cfe54e9c..7bb1c619f824d5 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -42,7 +42,7 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, absl::Span scalars) { std::vector vectors; absl::c_transform(scalars, std::back_inserter(vectors), [](xla::XlaOp x) { - return xla::Reshape(x, {1}, {xla::DynExpr::one}); + return xla::Reshape(x, {1}, {xla::DExpr::Const(1)}); }); return ConcatInDim(builder, vectors, 0); } @@ -247,11 +247,11 @@ XlaOp CombineShapePair(absl::Span pair, 
original_shape.dimensions(shape_pair.split_dim); std::vector reshape_dims(original_shape.dimensions().begin(), original_shape.dimensions().end()); - std::vector reshape_exprs(original_shape.expressions().begin(), - original_shape.expressions().end()); + std::vector reshape_exprs(original_shape.expressions().begin(), + original_shape.expressions().end()); reshape_dims[shape_pair.split_dim] = RoundUpTo(pre_split_size, 2); reshape_exprs[shape_pair.split_dim] = - DynExpr::_(RoundUpTo(pre_split_size, 2)); + DExpr::Const(RoundUpTo(pre_split_size, 2)); result = Reshape(result, reshape_dims, reshape_exprs); if (reshape_dims[shape_pair.split_dim] != pre_split_size) { result = Slice(result, diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc index 479d743a3fbc0c..553936d5790a95 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -168,7 +168,7 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { std::vector index_broadcast_dims; std::vector input_broadcast_dims; std::vector sizes; - std::vector expressions; + std::vector expressions; sizes.reserve(index_shape.dimensions().size()); expressions.reserve(index_shape.expressions().size()); for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { @@ -236,7 +236,7 @@ XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); std::vector index_broadcast_dims; std::vector sizes; - std::vector expressions; + std::vector expressions; const auto rank = index_shape.dimensions().size(); sizes.reserve(rank + 1); expressions.reserve(rank + 1); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 87fb695a5d5db6..3841206a9ebb8f 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -38,7 +38,6 @@ 
limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/match.h" -#include "tsl/platform/protobuf.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -1028,15 +1027,14 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( std::vector broadcast_dimensions; std::vector reshaped_dimensions; std::vector reshaped_dynamic_dimensions; - std::vector reshaped_expressions; + std::vector reshaped_expressions; for (int i = 0; i < operand_shape->dimensions().size(); i++) { if (operand_shape->dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape->dimensions(i)); reshaped_dynamic_dimensions.push_back( operand_shape->is_dynamic_dimension(i)); - reshaped_expressions.push_back( - operand_shape->expressions(i)); + reshaped_expressions.push_back(operand_shape->expressions(i)); } else { TF_RET_CHECK(operand_shape->dimensions(i) == 1 && operand_shape->is_static_dimension(i)) @@ -1050,7 +1048,8 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( Shape reshaped_shape = ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions, - reshaped_dynamic_dimensions, reshaped_expressions); + reshaped_dynamic_dimensions, + reshaped_expressions); // Eliminate the size one dimensions. // The added reshape reduces the rank of the tensor. 
Hence we cannot directly @@ -1101,11 +1100,11 @@ absl::StatusOr BroadcastToTargetRank( // Update target_size and target_exp with origin sizes and expressions using // broadcast_dimensions absl::Span target_dimensions = target_shape.dimensions(); - absl::Span target_expressions = target_shape.expressions(); + absl::Span target_expressions = target_shape.expressions(); std::vector target_size{target_dimensions.begin(), target_dimensions.end()}; - std::vector target_exp{target_expressions.begin(), - target_expressions.end()}; + std::vector target_exp(target_expressions.begin(), + target_expressions.end()); for (int64_t origin_dim = 0; origin_dim < origin_rank; origin_dim++) { int64_t target_dim = broadcast_dimensions[origin_dim]; target_size[target_dim] = origin_shape.dimensions(origin_dim); @@ -1129,7 +1128,7 @@ absl::StatusOr> ExtractDimensionSizesAndPadOnesToLeft( ? ConstantR1( /*builder=*/builder, /*values=*/{static_cast(op_shape->dimensions(i))}) - : Reshape(GetDimensionSize(op, i), {1}, {xla::DynExpr::one})); + : Reshape(GetDimensionSize(op, i), {1}, {xla::DExpr::Const(1)})); } return op_dims; } @@ -1152,7 +1151,7 @@ absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( /*builder=*/builder, /*values=*/{static_cast(output_shape.dimensions(i))}) : Reshape(GetDimensionSize(output, i), {1}, - {xla::DynExpr::one}); + {xla::DExpr::Const(1)}); } return MhloDynamicBroadcastInDim( scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {}, @@ -1536,20 +1535,12 @@ XlaOp XlaBuilder::Parameter( XlaOp XlaBuilder::Broadcast(XlaOp operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs) { + absl::Span broadcast_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferBroadcastShape( *operand_shape, broadcast_sizes, broadcast_exprs)); - - // The client-level broadcast op just appends dimensions on the left (adds - // 
lowest numbered dimensions). The HLO broadcast instruction is more - // flexible and can add new dimensions anywhere. The instruction's - // dimensions field maps operand dimensions to dimensions in the broadcast - // output, so to append dimensions on the left the instruction's dimensions - // should just be the n highest dimension numbers of the output shape where - // n is the number of input dimensions. const int64_t operand_rank = operand_shape->dimensions().size(); std::vector dimensions(operand_rank); for (int i = 0; i < operand_rank; ++i) { @@ -1562,11 +1553,9 @@ XlaOp XlaBuilder::Broadcast(XlaOp operand, XlaOp XlaBuilder::BroadcastInDim( XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions, - absl::Span out_dim_exp) { + absl::Span out_dim_exp) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - // Output shape, in the case of degenerate broadcast, the out_dim_size is - // not necessarily the same as the dimension sizes of the output shape. TF_ASSIGN_OR_RETURN(auto output_shape, ShapeUtil::MakeValidatedShape( operand_shape->element_type(), out_dim_size, out_dim_exp)); @@ -1596,16 +1585,13 @@ XlaOp XlaBuilder::BroadcastInDim( .status()); std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); std::vector in_dim_dynamic(out_dim_size.size(), false); - std::vector in_expressions(out_dim_exp.begin(), - out_dim_exp.end()); + std::vector in_expressions(out_dim_exp.begin(), out_dim_exp.end()); - // If out_dim_exp is empty just make expressions out of the static - // dimensions. 
if (out_dim_exp.empty()) { in_expressions.reserve(out_dim_size.size()); std::transform(out_dim_size.begin(), out_dim_size.end(), std::back_inserter(in_expressions), - [](int d) { return DynExpr::_(d); }); + [](int d) { return DExpr::Const(d); }); } for (int i = 0; i < broadcast_rank; i++) { @@ -1615,8 +1601,7 @@ XlaOp XlaBuilder::BroadcastInDim( : operand_shape->dimensions(i); in_dim_dynamic[broadcast_dimensions[i]] = operand_shape->is_bounded_dynamic_dimension(i); - in_expressions[broadcast_dimensions[i]] = - operand_shape->expressions(i); + in_expressions[broadcast_dimensions[i]] = operand_shape->expressions(i); } const auto& in_dim_shape = ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size, @@ -1624,13 +1609,9 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN( XlaOp in_dim_broadcast, InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); - - // If broadcast is not degenerate, return broadcasted result. if (ShapeUtil::Equal(in_dim_shape, output_shape)) { return in_dim_broadcast; } - - // Otherwise handle degenerate broadcast case. 
return AddBroadcastSequence(output_shape, in_dim_broadcast); }); } @@ -1666,15 +1647,16 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, - absl::Span start_exprs, - absl::Span limit_exprs, + absl::Span start_exprs, + absl::Span limit_exprs, absl::Span strides) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferSliceShape( - *operand_shape, start_indices, limit_indices, strides, - start_exprs, limit_exprs)); + Shape shape, ShapeInference::InferSliceShape(*operand_shape, + start_indices, + limit_indices, strides, + start_exprs, limit_exprs)); return SliceInternal(shape, operand, start_indices, limit_indices, strides); }); } @@ -1711,18 +1693,17 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, } XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, DynExpr* start_expr, - DynExpr* limit_expr, int64_t stride, + int64_t limit_index, const DExpr& start_expr, + const DExpr& limit_expr, int64_t stride, int64_t dimno) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); std::vector starts(shape->dimensions().size(), 0); std::vector limits(shape->dimensions().begin(), shape->dimensions().end()); - std::vector start_exprs(shape->dimensions().size(), - DynExpr::zero); - std::vector limit_exprs(shape->expressions().begin(), - shape->expressions().end()); + std::vector start_exprs(shape->dimensions().size(), DExpr::Const(0)); + std::vector limit_exprs(shape->expressions().begin(), + shape->expressions().end()); std::vector strides(shape->dimensions().size(), 1); starts[dimno] = start_index; limits[dimno] = limit_index; @@ -1736,7 +1717,7 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, XlaOp 
XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes, - absl::Span slice_exprs) { + absl::Span slice_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; @@ -1869,7 +1850,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, } XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, - absl::Span expressions, + absl::Span expressions, int64_t inferred_dimension) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -1892,13 +1873,11 @@ XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions) { + absl::Span expressions) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector dim_size_shape_ptrs; - TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes, - GetOperandShapes(dim_sizes)); - + TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes, GetOperandShapes(dim_sizes)); absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( @@ -1947,7 +1926,7 @@ XlaOp XlaBuilder::Collapse(XlaOp operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; - std::vector new_exprs; + std::vector new_exprs; for (int i = 0; i < original_shape->dimensions().size(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape->dimensions(i)); @@ -1955,12 +1934,11 @@ XlaOp XlaBuilder::Collapse(XlaOp operand, } else { new_sizes.back() *= original_shape->dimensions(i); new_exprs.back() = - *(new_exprs.back()) * *(original_shape->expressions(i)); + new_exprs.back() * original_shape->expressions(i); } } VLOG(3) << "new 
sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; - return Reshape(operand, new_sizes, new_exprs); }); } @@ -3998,14 +3976,14 @@ XlaOp XlaBuilder::AllToAllArray( return all_to_all; } DimensionVector sizes; - std::vector expressions; + std::vector expressions; const bool is_unbounded = operand_shape->is_unbounded_dynamic(); std::vector dynamic_sizes; auto GetR1DimensionSizeOrConstant = [&](XlaOp operand, int64_t dimension) -> XlaOp { if (operand_shape->is_unbounded_dynamic_dimension(dimension)) { return Reshape(GetDimensionSize(operand, dimension), {1}, - {DynExpr::one}); + {DExpr::Const(1)}); } return ConstantR1( this, {static_cast(operand_shape->dimensions(dimension))}); @@ -4022,12 +4000,12 @@ XlaOp XlaBuilder::AllToAllArray( continue; } sizes.push_back(split_count); - expressions.push_back(DynExpr::_(split_count)); + expressions.push_back(DExpr::Const(split_count)); sizes.push_back(operand_shape->is_unbounded_dynamic_dimension(i) ? Shape::kUnboundedSize : operand_shape->dimensions(i) / split_count); expressions.push_back( - (*operand_shape->expressions(i) / split_count)->s()); + (operand_shape->expressions(i) / split_count).simplify()); if (is_unbounded) { dynamic_sizes.push_back(r1_split_count); @@ -4581,21 +4559,14 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( *operand_shape, dimension)); - DynExpr* dim_expr = operand_shape->expressions(dimension); - if (dim_expr != nullptr && dim_expr->is_dynamic()) { + const DExpr& dim_expr = operand_shape->expressions(dimension); + if (dim_expr && dim_expr->is_dynamic()) { + // Carry the padded static dimension as the operand value so value + // inference can treat it as an upper bound for GetExpressionValue. 
XlaOp dim_bound = ConstantR0(this, operand_shape->dimensions(dimension)); - ExpressionProto expr_proto; - dim_expr->to_proto(&expr_proto); - std::string expr_textproto = - tsl::LegacyUnredactedShortDebugString(expr_proto); - VLOG(1) << "GetDimensionSize: expr_textproto is " << expr_textproto - << " ShortDebugString is " << expr_proto.ShortDebugString(); - TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute( - dim_bound, "dynamic_constant_index", "0")); - TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute( - dim_bound, "dynamic_constant_expr", expr_textproto)); - return dim_bound; + XlaOp expr_carrier = Broadcast(dim_bound, {1}, {dim_expr}); + return GetExpressionValue(expr_carrier); } // Calling GetDimensionSize on a static dimension returns a constant // instruction. @@ -5095,7 +5066,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { } XlaOp Broadcast(const XlaOp operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs) { + absl::Span broadcast_exprs) { return operand.builder()->Broadcast(operand, broadcast_sizes, broadcast_exprs); } @@ -5103,7 +5074,7 @@ XlaOp Broadcast(const XlaOp operand, absl::Span broadcast_sizes, XlaOp BroadcastInDim(const XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions, - absl::Span out_dim_exp) { + absl::Span out_dim_exp) { return operand.builder()->BroadcastInDim(operand, out_dim_size, broadcast_dimensions, out_dim_exp); } @@ -5141,7 +5112,7 @@ XlaOp Reshape(const XlaOp operand, absl::Span dimensions) { } XlaOp Reshape(const XlaOp operand, absl::Span dimensions, - absl::Span expressions) { + absl::Span expressions) { return operand.builder()->Reshape(operand, dimensions, expressions); } @@ -5152,14 +5123,23 @@ XlaOp Reshape(const Shape& shape, XlaOp operand) { XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions) { + absl::Span expressions) { return 
operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds, dims_are_dynamic, expressions); } +XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + start_exprs, limit_exprs, strides); +} + XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, - absl::Span new_exprs, + absl::Span new_exprs, int64_t inferred_dimension) { return operand.builder()->Reshape(operand, new_sizes, new_exprs, inferred_dimension); @@ -5176,15 +5156,6 @@ XlaOp Slice(const XlaOp operand, absl::Span start_indices, strides); } -XlaOp Slice(const XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span start_exprs, - absl::Span limit_exprs, - absl::Span strides) { - return operand.builder()->Slice(operand, start_indices, limit_indices, - start_exprs, limit_exprs, strides); -} - XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno) { return operand.builder()->SliceInDim(operand, start_index, limit_index, @@ -5192,15 +5163,15 @@ XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, } XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, - DynExpr* start_expr, DynExpr* limit_expr, int64_t stride, - int64_t dimno) { + const DExpr& start_expr, const DExpr& limit_expr, + int64_t stride, int64_t dimno) { return operand.builder()->SliceInDim(operand, start_index, limit_index, start_expr, limit_expr, stride, dimno); } XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, absl::Span slice_sizes, - absl::Span slice_exprs) { + absl::Span slice_exprs) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes, slice_exprs); } diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index 
af0202104e2652..7abdf120e4c286 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -520,11 +520,11 @@ class XlaBuilder { virtual XlaOp ConstantLiteral(const LiteralSlice& literal); XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs = {}); + absl::Span broadcast_exprs = {}); XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions, - absl::Span out_dim_exp = {}); + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. This is only intended for export to MHLO or @@ -548,7 +548,7 @@ class XlaBuilder { int64_t inferred_dimension = -1); XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span expressions, + absl::Span expressions, int64_t inferred_dimension = -1); XlaOp Reshape(const Shape& shape, XlaOp operand, @@ -557,7 +557,7 @@ class XlaBuilder { XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions = {}); + absl::Span expressions = {}); XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); @@ -570,8 +570,8 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, - absl::Span start_exprs, - absl::Span limit_exprs, + absl::Span start_exprs, + absl::Span limit_exprs, absl::Span strides); virtual absl::StatusOr SliceInternal( @@ -583,12 +583,13 @@ class XlaBuilder { int64_t limit_index, int64_t stride, int64_t dimno); virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, DynExpr* start_expr, - DynExpr* limit_expr, int64_t stride, int64_t dimno); + int64_t limit_index, const DExpr& start_expr, + const DExpr& limit_expr, int64_t stride, + int64_t dimno); XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes, - absl::Span 
slice_exprs = {}); + absl::Span slice_exprs = {}); virtual absl::StatusOr DynamicSliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); @@ -1264,12 +1265,12 @@ class XlaBuilder { friend XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, - absl::Span broadcast_expressions); + absl::Span broadcast_expressions); friend XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions, - absl::Span out_dim_exp); + absl::Span out_dim_exp); friend XlaOp MhloDynamicBroadcastInDim( XlaOp operand, XlaOp output_dimensions, @@ -1287,21 +1288,21 @@ class XlaBuilder { friend XlaOp Reshape(XlaOp operand, absl::Span dimensions); friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span expressions); + absl::Span expressions); friend XlaOp Reshape(const Shape& shape, XlaOp operand); friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions); + absl::Span expressions); friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); friend XlaOp ReshapeWithInferredDimension( XlaOp operand, absl::Span new_sizes, - absl::Span new_exprs, int64_t inferred_dimension); + absl::Span new_exprs, int64_t inferred_dimension); friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); @@ -1311,21 +1312,22 @@ class XlaBuilder { friend XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, - absl::Span start_exprs, - absl::Span limit_exprs, + absl::Span start_exprs, + absl::Span limit_exprs, absl::Span strides); friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, DynExpr* start_expr, - DynExpr* limit_expr, int64_t stride, int64_t dimno); + int64_t limit_index, const DExpr& start_expr, + const DExpr& limit_expr, 
int64_t stride, + int64_t dimno); friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes, - absl::Span slice_exprs); + absl::Span slice_exprs); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -2012,7 +2014,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs = {}); + absl::Span broadcast_exprs = {}); // This op broadcasts the `operand` to an output with the given `shape`. // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the @@ -2031,7 +2033,7 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, // {2 , 2}} XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions, - absl::Span out_dim_exp = {}); + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. This is only intended for export to MHLO or @@ -2077,7 +2079,7 @@ XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions); + absl::Span expressions = {}); // This is an experimental API for creating the mhlo.dynamic_reshape op from the // XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot @@ -2090,7 +2092,7 @@ XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); XlaOp Reshape(XlaOp operand, absl::Span dimensions); XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span expressions); + absl::Span expressions); // Enqueues a Reshape op that uses an explicit target shape. 
XlaOp Reshape(const Shape& shape, XlaOp operand); @@ -2101,7 +2103,7 @@ XlaOp Reshape(const Shape& shape, XlaOp operand); // is a dynamic dimension in the output, it must be the inferred dimension. XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, - absl::Span new_exprs, + absl::Span new_exprs, int64_t inferred_dimension); // Wrapper for Reshape. @@ -2141,8 +2143,8 @@ XlaOp Slice(XlaOp operand, absl::Span start_indices, XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, - absl::Span start_exprs, - absl::Span limit_exprs, + absl::Span start_exprs, + absl::Span limit_exprs, absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other @@ -2155,7 +2157,7 @@ XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, - DynExpr* start_expr, DynExpr* limit_expr, + const DExpr& start_expr, const DExpr& limit_expr, int64_t stride, int64_t dimno); // Enqueues a slice operation onto the computation that slices the 'operand' @@ -2170,7 +2172,7 @@ XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes, - absl::Span slice_exprs = {}); + absl::Span slice_exprs = {}); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. 
diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc index 02dd7813d033de..ecdc47c0877542 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc @@ -80,9 +80,9 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( broadcasted_input_shape.push_back(input_bit_width / output_bit_width); reshaped_input_shape.push_back(1); int64_t output_bit_width_mask = (int64_t{1} << output_bit_width) - 1; - std::vector reshaped_input_exprs( + std::vector reshaped_input_exprs( from_shape.expressions().begin(), from_shape.expressions().end()); - reshaped_input_exprs.push_back(DynExpr::_(1)); + reshaped_input_exprs.push_back(DExpr::Const(1)); TF_ASSIGN_OR_RETURN( input, BroadcastTo( Reshape(input, reshaped_input_shape, reshaped_input_exprs), diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc index a7d62392dec79a..cd74b9406c3639 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc @@ -81,16 +81,16 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { int64_t lhs_contracting_size = 1; bool lhs_contracting_dynamic = false; int64_t lhs_contracting_multiplier_accu = 1; - DynExpr* lhs_contracting_expression = DynExpr::one; + DExpr lhs_contracting_expression = DExpr::Const(1); int64_t lhs_non_contracting_size = 1; bool lhs_non_contracting_dynamic = false; int64_t lhs_non_contracting_multiplier_accu = 1; - DynExpr* lhs_non_contracting_expression = DynExpr::one; + DExpr lhs_non_contracting_expression = DExpr::Const(1); std::vector batch_dim_sizes; batch_dim_sizes.reserve(num_batch_dims); std::vector batch_dynamic_dims; batch_dynamic_dims.reserve(num_batch_dims); - std::vector 
batch_expressions; + std::vector batch_expressions; batch_expressions.reserve(num_batch_dims); bool lhs_contracting_is_static = true; @@ -101,7 +101,7 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { lhs_contracting_size *= lhs_shape.dimensions(i); lhs_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); lhs_contracting_expression = - (*lhs_contracting_expression) * (*lhs_shape.expressions(i)); + lhs_contracting_expression * lhs_shape.expressions(i); } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), i)) { batch_dim_sizes.push_back(lhs_shape.dimensions(i)); @@ -112,7 +112,7 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { lhs_non_contracting_size *= lhs_shape.dimensions(i); lhs_non_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); lhs_non_contracting_expression = - (*lhs_non_contracting_expression) * (*lhs_shape.expressions(i)); + lhs_non_contracting_expression * lhs_shape.expressions(i); } } @@ -139,15 +139,15 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { std::vector lhs_reshape_dims = batch_dim_sizes; std::vector lhs_reshape_dynamic_dims = batch_dynamic_dims; - std::vector lhs_reshape_expressions = batch_expressions; + std::vector lhs_reshape_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { lhs_reshape_dims.push_back(lhs_non_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_non_contracting_dynamic); - lhs_reshape_expressions.push_back(lhs_non_contracting_expression->s()); + lhs_reshape_expressions.push_back(lhs_non_contracting_expression.simplify()); } lhs_reshape_dims.push_back(lhs_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_contracting_dynamic); - lhs_reshape_expressions.push_back(lhs_contracting_expression->s()); + lhs_reshape_expressions.push_back(lhs_contracting_expression.simplify()); // Reshape the contracting and non-contracting dimensions together. 
auto sh_lhs = ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims, lhs_reshape_dynamic_dims, @@ -165,11 +165,11 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { int64_t rhs_non_contracting_size = 1; bool rhs_non_contracting_dynamic = false; int64_t rhs_non_contracting_multiplier_accu = 1; - DynExpr* rhs_non_contracting_expression = DynExpr::one; + DExpr rhs_non_contracting_expression = DExpr::Const(1); int64_t rhs_contracting_size = 1; bool rhs_contracting_dynamic = false; int64_t rhs_contracting_multiplier_accu = 1; - DynExpr* rhs_contracting_expression = DynExpr::one; + DExpr rhs_contracting_expression = DExpr::Const(1); bool rhs_contracting_is_static = true; bool rhs_non_contracting_is_static = true; @@ -179,14 +179,14 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_contracting_size *= rhs_shape.dimensions(i); rhs_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); rhs_contracting_expression = - (*rhs_contracting_expression) * (*rhs_shape.expressions(i)); + rhs_contracting_expression * rhs_shape.expressions(i); } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), i)) { rhs_non_contracting_dims.push_back(i); rhs_non_contracting_size *= rhs_shape.dimensions(i); rhs_non_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); rhs_non_contracting_expression = - (*rhs_non_contracting_expression) * (*rhs_shape.expressions(i)); + rhs_non_contracting_expression * rhs_shape.expressions(i); } } @@ -215,12 +215,12 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_reshape_dims.push_back(rhs_contracting_size); std::vector rhs_reshape_dynamic_dims = batch_dynamic_dims; rhs_reshape_dynamic_dims.push_back(rhs_contracting_dynamic); - std::vector rhs_reshape_expressions = batch_expressions; - rhs_reshape_expressions.push_back(rhs_contracting_expression->s()); + std::vector rhs_reshape_expressions = batch_expressions; + 
rhs_reshape_expressions.push_back(rhs_contracting_expression.simplify()); if (rhs_non_contracting_size > 1) { rhs_reshape_dims.push_back(rhs_non_contracting_size); rhs_reshape_dynamic_dims.push_back(rhs_non_contracting_dynamic); - rhs_reshape_expressions.push_back(rhs_non_contracting_expression->s()); + rhs_reshape_expressions.push_back(rhs_non_contracting_expression.simplify()); } // Reshape the contracting and non-contracting dimensions together. auto sh_rhs = ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims, @@ -234,16 +234,16 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { std::vector dot_dims = batch_dim_sizes; std::vector dot_dynamic_dims = batch_dynamic_dims; - std::vector dot_expressions = batch_expressions; + std::vector dot_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { dot_dims.push_back(lhs_non_contracting_size); dot_dynamic_dims.push_back(lhs_non_contracting_dynamic); - dot_expressions.push_back(lhs_non_contracting_expression->s()); + dot_expressions.push_back(lhs_non_contracting_expression.simplify()); } if (rhs_non_contracting_size > 1) { dot_dims.push_back(rhs_non_contracting_size); dot_dynamic_dims.push_back(rhs_non_contracting_dynamic); - dot_expressions.push_back(rhs_non_contracting_expression->s()); + dot_expressions.push_back(rhs_non_contracting_expression.simplify()); } DotDimensionNumbers dot_dnums; diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc index 27b654587ba01a..17775e0f132dfe 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc @@ -58,9 +58,9 @@ std::vector ConcatVectors(absl::Span xs, return output; } -std::vector ConcatEVectors(absl::Span xs, - absl::Span ys) { - std::vector output; +std::vector ConcatEVectors(absl::Span xs, + absl::Span ys) { + std::vector output; output.reserve(xs.size() + ys.size()); 
std::copy(xs.begin(), xs.end(), std::back_inserter(output)); std::copy(ys.begin(), ys.end(), std::back_inserter(output)); @@ -229,12 +229,12 @@ absl::StatusOr QrExpander::QrBlock( const int64_t m = ShapeUtil::GetDimension(a_shape, -2); const int64_t n = ShapeUtil::GetDimension(a_shape, -1); - DynExpr* m_exp = ShapeUtil::GetExpression(a_shape, -2); - DynExpr* n_exp = ShapeUtil::GetExpression(a_shape, -1); + const DExpr& m_exp = ShapeUtil::GetExpression(a_shape, -2); + const DExpr& n_exp = ShapeUtil::GetExpression(a_shape, -1); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); - std::vector batch_exprs(num_batch_dims); + std::vector batch_exprs(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); batch_exprs[i] = ShapeUtil::GetExpression(a_shape, i); @@ -261,10 +261,10 @@ absl::StatusOr QrExpander::QrBlock( minor_dim + 1); std::vector shape = batch_dims; - std::vector exprs = batch_exprs; + std::vector exprs = batch_exprs; shape.push_back(1); shape.push_back(m); - exprs.push_back(DynExpr::one); + exprs.push_back(DExpr::Const(1)); exprs.push_back(m_exp); auto v_broadcast = Reshape(v, shape, exprs); // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @ @@ -280,7 +280,7 @@ absl::StatusOr QrExpander::QrBlock( // a[j, j] = beta // a[j+1:,j] = v[j+1:] auto iota = - Reshape(Iota(a.builder(), S32, m), {m, 1}, {m_exp, DynExpr::one}); + Reshape(Iota(a.builder(), S32, m), {m, 1}, {m_exp, DExpr::Const(1)}); auto predecessor_mask = ConvertElementType(Lt(iota, j), type); auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), std::vector(batch_dims.size(), 1)); diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc index 2a6f428b65aa82..2e84d07f076078 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc +++ 
b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc @@ -84,7 +84,7 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, } XlaOp final_state = ConcatInDim( - &builder, {Reshape(key_op, {1}, {DynExpr::one}), output.state}, 0); + &builder, {Reshape(key_op, {1}, {DExpr::Const(1)}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); TF_ASSIGN_OR_RETURN(HloComputation * new_computation, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 90bf284e206a77..6e50c950e1d7e4 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -2673,7 +2673,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; - std::vector dimExpressions; + std::vector dimExpressions; for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); @@ -2685,7 +2685,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); - dimExpressions.push_back(xla::DynExpr::_(-40)); // Don't know. 
+ dimExpressions.push_back(xla::DExpr::Unknown(40)); } value_map[op] = xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic, dimExpressions); @@ -3008,7 +3008,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; - std::vector dimExpressions; + std::vector dimExpressions; for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); dimSizes.push_back(xla::Reshape(runtimeSizeX1, {})); @@ -3019,7 +3019,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); - dimExpressions.push_back(xla::DynExpr::_(-50)); // Don't know + dimExpressions.push_back(xla::DExpr::Unknown(50)); } value_map[op] = xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 9369b89d77ec61..a06d517914bdf6 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -150,7 +150,7 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); - std::vector expressions(rank, DynExpr::_(-60)); + std::vector expressions(rank, DExpr::Unknown(60)); for (int64_t dim = 0; dim < rank; ++dim) { int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 54c34104f4b7b0..bd26c1c87621cd 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ 
b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -59,23 +59,6 @@ bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } -bool HasDynamicDimensions(const Shape& shape) { - for (int64_t i = 0; i < shape.dimensions().size(); ++i) { - if (shape.is_dynamic_dimension(i) || - shape.expressions(i)->is_dynamic()) { - return true; - } - } - return false; -} - -bool IsDynamicDot(const HloInstruction* hlo) { - return hlo->opcode() == HloOpcode::kDot && - (HasDynamicDimensions(hlo->shape()) || - HasDynamicDimensions(hlo->operand(0)->shape()) || - HasDynamicDimensions(hlo->operand(1)->shape())); -} - bool HasExactlyOneUse(const HloInstruction& hlo_instr) { return hlo_instr.user_count() == 1 && absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; @@ -159,10 +142,6 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return FusionDecision::Forbid("Don't fuse large constants."); } - if (IsDynamicDot(producer) || IsDynamicDot(consumer)) { - return FusionDecision::Forbid("Do not fuse dynamic dots on CPU."); - } - if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow(); diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e201b87a4a2a6a..9ceac786ba26a6 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -165,20 +165,17 @@ bool HasDynamicMatmulDims(const DotInfo& dot_info) { const Shape& rhs_shape = dot_info.rhs_shape; const DotDimensionNumbers& dim_nums = dot_info.dim_nums; - DynExpr* m_expr = lhs_shape.dimensions().size() <= 1 - ? 
DynExpr::one - : lhs_shape.expressions( - 1LL - dim_nums.lhs_contracting_dimensions(0)); - DynExpr* k_expr = - lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)); - DynExpr* n_expr = rhs_shape.dimensions().size() <= 1 - ? DynExpr::one - : rhs_shape.expressions( - 1LL - dim_nums.rhs_contracting_dimensions(0)); - - return (m_expr != nullptr && m_expr->is_dynamic()) || - (k_expr != nullptr && k_expr->is_dynamic()) || - (n_expr != nullptr && n_expr->is_dynamic()); + DExpr m_expr = lhs_shape.dimensions().size() <= 1 + ? DExpr::Const(1) + : lhs_shape.expressions( + 1LL - dim_nums.lhs_contracting_dimensions(0)); + DExpr k_expr = lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)); + DExpr n_expr = rhs_shape.dimensions().size() <= 1 + ? DExpr::Const(1) + : rhs_shape.expressions( + 1LL - dim_nums.rhs_contracting_dimensions(0)); + + return m_expr->is_dynamic() || k_expr->is_dynamic() || n_expr->is_dynamic(); } // Returns dot implementation strategy for non-batch dot operations. @@ -278,11 +275,11 @@ class DotOpEmitter { // The number of columns on the RHS. int64_t n; - DynExpr* m_expr; + DExpr m_expr; - DynExpr* k_expr; + DExpr k_expr; - DynExpr* n_expr; + DExpr n_expr; // True if the LHS matrix is column major. 
bool lhs_column_major; @@ -894,9 +891,12 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { std::swap(transpose_lhs, transpose_rhs); } - llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); - llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); - llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + llvm::Value* m_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); b_->CreateCall( matmul_func, @@ -982,13 +982,15 @@ absl::Status DotOpEmitter::EmitCallToBatchRuntime() { std::swap(transpose_lhs, transpose_rhs); } - llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); - llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); - llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); - DynExpr* batch_size_expr = lhs_shape.expressions(0); - if (batch_size_expr == nullptr) { - batch_size_expr = DynExpr::_(lhs_shape.dimensions(0)); - } + llvm::Value* m_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = + xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + DExpr batch_size_expr = + lhs_shape.expressions(0) ? lhs_shape.expressions(0) + : DExpr::Const(lhs_shape.dimensions(0)); llvm::Value* batch_size_val = xla::llvm_ir::EmitExpression(b_, batch_size_expr); @@ -1029,12 +1031,12 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { ? 1LL : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)), /*m_expr=*/lhs_shape.dimensions().size() <= 1 - ? DynExpr::one + ? 
DExpr::Const(1) : lhs_shape.expressions(1LL - dim_nums.lhs_contracting_dimensions(0)), /*k_expr=*/ lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)), /*n_expr=*/rhs_shape.dimensions().size() <= 1 - ? DynExpr::one + ? DExpr::Const(1) : rhs_shape.expressions(1LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || @@ -1068,12 +1070,12 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { ? 1LL : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)), /*m_expr=*/lhs_shape.dimensions().size() <= 1 - ? DynExpr::one + ? DExpr::Const(1) : lhs_shape.expressions(2LL - dim_nums.lhs_contracting_dimensions(0)), /*k_expr=*/ lhs_shape.expressions(1LL + dim_nums.lhs_contracting_dimensions(0)), /*n_expr=*/rhs_shape.dimensions().size() <= 1 - ? DynExpr::one + ? DExpr::Const(1) : rhs_shape.expressions(2LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || @@ -1154,7 +1156,7 @@ absl::Status EmitNonBatchDotOperation( Shape DropFirstDim(const Shape& shape) { absl::Span array_shape_dims(shape.dimensions()); - absl::Span array_shape_exprs(shape.expressions()); + absl::Span array_shape_exprs(shape.expressions()); array_shape_dims.remove_prefix(1); array_shape_exprs.remove_prefix(1); return ShapeUtil::MakeShapeWithDescendingLayout( @@ -1163,19 +1165,19 @@ Shape DropFirstDim(const Shape& shape) { Shape CollapseFirstNDims(const Shape& shape, int64_t n) { absl::Span input_shape_dims(shape.dimensions()); - absl::Span input_expressions(shape.expressions()); + absl::Span input_expressions(shape.expressions()); int64_t prefix_dim = std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, 1ll, std::multiplies()); - DynExpr* prefix_expression = std::accumulate( - input_expressions.begin(), input_expressions.begin() + n, DynExpr::one, - [](DynExpr* acc, 
DynExpr* v) { return (*acc) * (*v); }); + DExpr prefix_expression = std::accumulate( + input_expressions.begin(), input_expressions.begin() + n, DExpr::Const(1), + [](DExpr acc, const DExpr& v) { return acc * v; }); DimensionVector result_dims; - std::vector result_expressions; + std::vector result_expressions; result_dims.push_back(prefix_dim); - result_expressions.push_back(prefix_expression->s()); + result_expressions.push_back(prefix_expression.simplify()); std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), std::back_inserter(result_dims)); std::copy(input_expressions.begin() + n, input_expressions.end(), diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 87a14be281a93e..c72fcbf0a1cc84 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -76,6 +76,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/map_util.h" +#include "xla/printer.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" @@ -2067,7 +2068,7 @@ absl::Status IrEmitter::HandleSlice(HloInstruction* slice) { const int64_t memcpy_elements = primitive_elements_per_logical_element * memcpy_logical_elements; - EmitTransferElements(memcpy_dest, memcpy_source, DynExpr::_(memcpy_elements), + EmitTransferElements(memcpy_dest, memcpy_source, DExpr::Const(memcpy_elements), slice->shape().element_type(), target_array, source_array); @@ -2362,9 +2363,8 @@ absl::Status IrEmitter::HandleShapeExprValue(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); llvm_ir::IrArray out_array = GetIrArrayFor(hlo); - - llvm::Value* expr_value = - llvm_ir::EmitExpression(b(), hlo->operand(0)->shape().expressions(0)); + const auto& expr = hlo->operand(0)->shape().expressions(0); + llvm::Value* expr_value = llvm_ir::EmitExpression(b(), expr); auto it = 
emitted_value_.find(hlo); if (it == emitted_value_.end()) { @@ -3149,9 +3149,9 @@ absl::Status EmitFastConcatenate( target_array.EmitArrayElementAddress(target_index, &b, "target_region"); llvm::Value* byte_offset_into_target_region = b.getInt64(0); - DynExpr* inner_exprs_product = absl::c_accumulate( - inner_dims, DynExpr::one, [&](DynExpr* product, int64_t inner_dim) { - return *product * *output_shape.expressions(inner_dim); + DExpr inner_exprs_product = absl::c_accumulate( + inner_dims, DExpr::Const(1), [&](DExpr product, int64_t inner_dim) { + return product * output_shape.expressions(inner_dim); }); // For each operand, emit a memcpy from the operand to the target of size @@ -3169,13 +3169,14 @@ absl::Status EmitFastConcatenate( auto cexpr = input_shape.expressions(concat_dim); - ::xla::cpu::EmitTransferElements(copy_target_address, copy_source_address, - (*inner_exprs_product * *cexpr)->s(), - primitive_type, target_array, source_array, - module, b); + ::xla::cpu::EmitTransferElements( + copy_target_address, copy_source_address, + (inner_exprs_product * cexpr).simplify(), primitive_type, target_array, + source_array, module, b); llvm::Value* concat_dim_count = xla::llvm_ir::EmitExpression( - &b, (*inner_exprs_product * *input_shape.expressions(concat_dim))->s()); + &b, (inner_exprs_product * input_shape.expressions(concat_dim)) + .simplify()); llvm::Value* concat_dim_size = b.CreateMul(concat_dim_count, b.getInt64(primitive_type_size)); @@ -3392,7 +3393,7 @@ llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, } void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, - xla::DynExpr* element_count, + const xla::DExpr& element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array) { @@ -3402,7 +3403,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } void EmitTransferElements(llvm::Value* target, llvm::Value* source, - 
xla::DynExpr* element_count, + const xla::DExpr& element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, @@ -3415,7 +3416,7 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm::Type* primitive_llvm_type = llvm_ir::PrimitiveTypeToIrType(primitive_type, module->getContext()); - if (element_count == DynExpr::one) { + if (element_count->is_constant() && element_count->get_val() == 1) { auto* load_instruction = b.CreateAlignedLoad(primitive_llvm_type, source, element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); @@ -4098,21 +4099,21 @@ absl::Status IrEmitter::EmitMemcpy(const HloInstruction& source, auto expressions = shape.expressions(); bool is_dynamic = std::any_of(expressions.begin(), expressions.end(), - [](DynExpr* e) { return e->is_dynamic(); }); + [](const DExpr& e) { return e && e->is_dynamic(); }); if (is_dynamic) { llvm::LLVMContext& ctx = b()->getContext(); llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); int64_t dimensions_accu = 1; - DynExpr* expression_accu = DynExpr::one; + DExpr expression_accu = DExpr::Const(1); for (int i = 0; i < shape.dimensions_size(); i++) { - auto expression = shape.expressions(i); - if (expression->is_dynamic()) { + const auto& expression = shape.expressions(i); + if (expression && expression->is_dynamic()) { dimensions_accu *= shape.dimensions(i); - expression_accu = (*expression_accu) * (*expression); + expression_accu = expression_accu * expression; } } llvm::Value* expr_value = - xla::llvm_ir::EmitExpression(b(), expression_accu->s()); + xla::llvm_ir::EmitExpression(b(), expression_accu.simplify()); // Divide the size in bytes by the size of the dynamic dimension(s). 
// TODO: make that less hacky llvm::ConstantInt* size = diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index f4069fdabec932..53583810acd5cb 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -570,7 +570,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". void EmitTransferElements(llvm::Value* target, llvm::Value* source, - xla::DynExpr* element_count, PrimitiveType primitive_type, + const xla::DExpr& element_count, + PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); @@ -860,7 +861,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Decoupled implementation of IrEmitter::EmitTransferElements. void EmitTransferElements(llvm::Value* target, llvm::Value* source, - xla::DynExpr* element_count, PrimitiveType primitive_type, + const xla::DExpr& element_count, + PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, llvm::Module* module, llvm::IRBuilderBase& b); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index a5d8fbeacbe19d..d09991912d9db9 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -74,6 +74,7 @@ limitations under the License. 
#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/printer.h" #include "xla/shape.h" #include "xla/shape_partition.h" #include "xla/stream_executor/launch_dim.h" @@ -203,7 +204,7 @@ IrEmitter2::EmitGetExpressionValueHostKernel(const HloInstruction* getBatch) { EmitKernelPrototype(getBatch)); llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; llvm_ir::IrArray output_array = kernel_prototype.results[0]; - xla::DynExpr* expr = getBatch->operand(0)->shape().expressions(0); + const auto& expr = getBatch->operand(0)->shape().expressions(0); llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); llvm::Value* bdim_value = llvm_ir::EmitExpression(&b, expr); diff --git a/third_party/xla/xla/service/dynamic_constant_rewriter.cc b/third_party/xla/xla/service/dynamic_constant_rewriter.cc index d8b89dc7218d5f..315c9d6615f3ba 100644 --- a/third_party/xla/xla/service/dynamic_constant_rewriter.cc +++ b/third_party/xla/xla/service/dynamic_constant_rewriter.cc @@ -25,34 +25,6 @@ namespace xla { namespace { -DynExpr* DynExprFromProto(const ExpressionProto& proto) { - switch (proto.node_type_case()) { - case ExpressionProto::kConstantValue: - return DynExpr::_(proto.constant_value()); - case ExpressionProto::kVariableId: - return DynExpr::V(proto.variable_id()); - case ExpressionProto::kAddNode: { - const auto& add = proto.add_node(); - return new Add(DynExprFromProto(add.lhs()), DynExprFromProto(add.rhs())); - } - case ExpressionProto::kSubNode: { - const auto& sub = proto.sub_node(); - return new Sub(DynExprFromProto(sub.lhs()), DynExprFromProto(sub.rhs())); - } - case ExpressionProto::kMulNode: { - const auto& mul = proto.mul_node(); - return new Mul(DynExprFromProto(mul.lhs()), DynExprFromProto(mul.rhs())); - } - case ExpressionProto::kDivNode: { - const auto& div = proto.div_node(); - return new 
Div(DynExprFromProto(div.lhs()), DynExprFromProto(div.rhs())); - } - case ExpressionProto::NODE_TYPE_NOT_SET: - default: - return nullptr; - } -} - absl::StatusOr BuildDynamicConstantReplacement( HloInstruction* constant_instr) { TF_RET_CHECK(constant_instr->opcode() == HloOpcode::kConstant); @@ -72,8 +44,8 @@ absl::StatusOr BuildDynamicConstantReplacement( TF_RET_CHECK(tsl::protobuf::TextFormat::ParseFromString(expr_it->second, &expr_proto)) << "Failed to parse dynamic_constant_expr=" << expr_it->second; - DynExpr* expr = DynExprFromProto(expr_proto); - TF_RET_CHECK(expr != nullptr); + DExpr expr = DExprFromProto(expr_proto); + TF_RET_CHECK(expr); const Shape& shape = constant_instr->shape(); TF_RET_CHECK(shape.IsArray()); @@ -103,7 +75,7 @@ absl::StatusOr BuildDynamicConstantReplacement( HloComputation* computation = constant_instr->parent(); Shape carrier_shape = ShapeUtil::MakeShape(S32, {1}); - carrier_shape.set_expression(0, expr); + carrier_shape.set_expression(0, std::move(expr)); HloInstruction* carrier = computation->AddInstruction( HloInstruction::CreateConstant( LiteralUtil::CreateR1( diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index b2f20ebcacdf0a..88e43003ae4f4f 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -3289,8 +3289,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( cases.emplace_back(current_offset, operand); llvm::Value* cdim = source_index.GetConstantWithIndexType( operand->shape().dimensions(concat_dim)); - xla::DynExpr* concat_expr = operand->shape().expressions(concat_dim); - if (concat_expr != nullptr && concat_expr->is_dynamic()) { + const auto& concat_expr = operand->shape().expressions(concat_dim); + if (concat_expr->is_dynamic()) { cdim = llvm_ir::EmitExpression(b_, concat_expr); } current_offset = b_->CreateAdd(current_offset, cdim, "current_offset"); @@ -3626,8 
+3626,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( int64_t shape_dim = hlo->operand(0)->shape().dimensions(i); llvm::Value* bound = index_typed_const(shape_dim); - xla::DynExpr* operand_expr = hlo->operand(0)->shape().expressions(i); - if (operand_expr != nullptr && operand_expr->is_dynamic()) { + const auto& operand_expr = hlo->operand(0)->shape().expressions(i); + if (operand_expr->is_dynamic()) { bound = llvm_ir::EmitExpression(b_, operand_expr); } @@ -3905,8 +3905,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); std::vector source_multi_index = target_index.multidim(); for (int64_t dim : hlo->dimensions()) { - xla::DynExpr* dim_expr = hlo->shape().expressions(dim); - if (dim_expr != nullptr && dim_expr->is_dynamic()) { + const auto& dim_expr = hlo->shape().expressions(dim); + if (dim_expr->is_dynamic()) { llvm::Value* one = target_index.GetConstantWithIndexType(1); llvm::Value* expr_value = llvm_ir::EmitExpression(b_, dim_expr); source_multi_index[dim] = @@ -4270,9 +4270,9 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( int64_t dim_bound = reduce_window->inputs()[0]->shape().dimensions(i); llvm::Value* shape_bound = index_typed_const(dim_bound); - xla::DynExpr* window_expr = + const auto& window_expr = reduce_window->inputs()[0]->shape().expressions(i); - if (window_expr != nullptr && window_expr->is_dynamic()) { + if (window_expr->is_dynamic()) { llvm::Value* expr_value = llvm_ir::EmitExpression(b_, window_expr); shape_bound = expr_value; } diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index 3a33c5f89a2e36..25cecfdc0b6d8e 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -184,7 +184,8 @@ absl::StatusOr MakeDynamicSliceHlo( TF_ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( - 
operand->shape(), scalar_start_indices_shapes, slice_sizes)); + operand->shape(), scalar_start_indices_shapes, slice_sizes, + operand->shape().expressions())); return computation->AddInstruction( HloInstruction::CreateDynamicSlice(dynamic_slice_shape, operand, start_indices, slice_sizes), @@ -655,12 +656,12 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, CHECK_GE(operand_shape.dimensions_size(), n); int64_t new_shape_leading_bound = 1; bool new_shape_leading_is_dynamic = false; - DynExpr* new_shape_leading_expression = DynExpr::one; + DExpr new_shape_leading_expression = DExpr::Const(1); for (int64_t i = 0; i < n; i++) { new_shape_leading_bound *= operand_shape.dimensions(i); new_shape_leading_is_dynamic |= operand_shape.is_dynamic_dimension(i); new_shape_leading_expression = - (*new_shape_leading_expression) * (*operand_shape.expressions(i)); + new_shape_leading_expression * operand_shape.expressions(i); } std::vector new_shape_dims; @@ -678,9 +679,9 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, operand_shape.dynamic_dimensions().end(), std::back_inserter(new_shape_dynamic_dims)); - std::vector new_shape_expressions; + std::vector new_shape_expressions; new_shape_expressions.reserve(operand_shape.dimensions_size() - n + 1); - new_shape_expressions.push_back(new_shape_leading_expression->s()); + new_shape_expressions.push_back(new_shape_leading_expression.simplify()); auto exprs = operand_shape.expressions(); std::copy(exprs.begin() + n, exprs.end(), std::back_inserter(new_shape_expressions)); diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 43686289255bd2..7748c61b4e67b8 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -281,8 +281,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. 
for (int64_t i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - xla::DynExpr* input_expr = input_shape.expressions(i); - bool is_dynamic = input_expr != nullptr && input_expr->is_dynamic(); + const auto& input_expr = input_shape.expressions(i); + bool is_dynamic = input_expr->is_dynamic(); llvm::Value* divisor = is_dynamic ? llvm_ir::EmitExpression(builder, input_expr) : GetConstantWithIndexType(input_shape.dimensions(i)); @@ -567,7 +567,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // it's always indiced with 0 (i.e. the dynamic dimension has no impact on the // address computation). std::vector gep_dims; - std::vector gep_expressions; + std::vector gep_expressions; gep_dims.reserve(shape_.dimensions().size()); gep_expressions.reserve(shape_.dimensions().size()); for (int64_t i = 0; i < shape_.dimensions().size(); ++i) { @@ -578,7 +578,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, bool dynamic_first_dim = gep_expressions[0]->is_dynamic() && std::all_of(gep_expressions.begin() + 1, gep_expressions.end(), - [](DynExpr* e) { return e->is_constant(); }); + [](const DExpr& e) { return e->is_constant(); }); if (!dynamic_first_dim && shape_.has_dynamic_expr()) { llvm::Type* element_type = PrimitiveTypeToIrType(shape_.element_type(), b->getContext()); diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc index 562260717ba35d..076027184ba66f 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc @@ -189,7 +189,7 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, std::unique_ptr ForLoopNest::AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, bool prevent_vectorization, - DynExpr* expression) { + DExpr expression) { return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), 
unroll_mode, prevent_vectorization, expression); } @@ -197,7 +197,7 @@ std::unique_ptr ForLoopNest::AddLoop( std::unique_ptr ForLoopNest::AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization, - DynExpr* expression) { + DExpr expression) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -205,8 +205,7 @@ std::unique_ptr ForLoopNest::AddLoop( llvm::Value* actual_end = end_index; if (expression && expression->is_dynamic()) { // Get batch dim and compare with end_index to use minimum value - llvm::Value* expr_value = - llvm_ir::EmitExpression(b_, expression); + llvm::Value* expr_value = llvm_ir::EmitExpression(b_, expression); actual_end = b_->CreateSelect(b_->CreateICmpULT(end_index, expr_value), end_index, expr_value, "loop_end_min"); } @@ -231,7 +230,7 @@ std::unique_ptr ForLoopNest::AddLoop( std::unique_ptr ForLoopNest::AddLoop( int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization, - DynExpr* expression) { + DExpr expression) { CHECK_LE(start_index, end_index); llvm::Value* end = (expression && expression->is_dynamic()) @@ -246,7 +245,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization, - DynExpr* expression) { + DExpr expression) { CHECK_LE(start_index, end_index); llvm::Value* end = (expression && expression->is_dynamic()) diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.h b/third_party/xla/xla/service/llvm_ir/llvm_loop.h index 5414f019121d7d..763bbf2ab7b642 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.h @@ -201,14 +201,14 @@ class ForLoopNest { absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* 
stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, DynExpr* expression = nullptr); + bool prevent_vectorization = false, DExpr expression = DExpr()); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, DynExpr* expression = nullptr); + bool prevent_vectorization = false, DExpr expression = DExpr()); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. @@ -216,13 +216,13 @@ class ForLoopNest { int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, DynExpr* expression = nullptr); + bool prevent_vectorization = false, DExpr expression = DExpr()); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, DynExpr* expression = nullptr); + bool prevent_vectorization = false, DExpr expression = DExpr()); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index fa27de28a84201..d5ce4654e460be 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -68,6 +68,7 @@ limitations under the License. 
#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/printer.h" #include "xla/service/cpu/cpu_options.h" #include "xla/service/dump.h" #include "xla/service/hlo_module_config.h" @@ -881,46 +882,61 @@ llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier, return bdim_scaled; } -llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { - llvm::Function* function = b->GetInsertBlock()->getParent(); +static llvm::Value* EmitExpressionImpl(llvm::IRBuilderBase* b, + const DynExpr& expr) { llvm::LLVMContext& ctx = b->getContext(); llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); - if (expr == nullptr) return nullptr; - if (expr->is_constant()) - return llvm::ConstantInt::get(i64Type, expr->get_val(), true); - if (Variable* var_node = dynamic_cast(expr)) { + if (expr.is_constant()) return llvm::ConstantInt::get(i64Type, expr.get_val(), true); + if (expr.kind() == DExpr::Kind::kUnknown) { + return nullptr; + } + if (expr.kind() == DExpr::Kind::kVariable) { // For now we can just use %bdim... 
return GetBatchDimByName(b); } - if (Mul* mul_node = dynamic_cast(expr)) { - llvm::Value* v_lhs = EmitExpression(b, mul_node->get_lhs()); - llvm::Value* v_rhs = EmitExpression(b, mul_node->get_rhs()); + if (expr.kind() == DExpr::Kind::kMul) { + auto* mul_node = static_cast(&expr); + llvm::Value* v_lhs = EmitExpressionImpl(b, *mul_node->get_lhs()); + llvm::Value* v_rhs = EmitExpressionImpl(b, *mul_node->get_rhs()); return b->CreateMul(v_lhs, v_rhs, "mul_dims"); } // TODO: Check if this should ever happen - if (Div* div_node = dynamic_cast(expr)) { - llvm::Value* v_lhs = EmitExpression(b, div_node->get_lhs()); - llvm::Value* v_rhs = EmitExpression(b, div_node->get_rhs()); + if (expr.kind() == DExpr::Kind::kDiv) { + auto* div_node = static_cast(&expr); + llvm::Value* v_lhs = EmitExpressionImpl(b, *div_node->get_lhs()); + llvm::Value* v_rhs = EmitExpressionImpl(b, *div_node->get_rhs()); return b->CreateUDiv(v_lhs, v_rhs, "div_dims"); } - if (Add* add_node = dynamic_cast(expr)) { - llvm::Value* v_lhs = EmitExpression(b, add_node->get_lhs()); - llvm::Value* v_rhs = EmitExpression(b, add_node->get_rhs()); + if (expr.kind() == DExpr::Kind::kAdd) { + auto* add_node = static_cast(&expr); + llvm::Value* v_lhs = EmitExpressionImpl(b, *add_node->get_lhs()); + llvm::Value* v_rhs = EmitExpressionImpl(b, *add_node->get_rhs()); return b->CreateAdd(v_lhs, v_rhs, "add_dims"); } - if (Sub* sub_node = dynamic_cast(expr)) { - llvm::Value* v_lhs = EmitExpression(b, sub_node->get_lhs()); - llvm::Value* v_rhs = EmitExpression(b, sub_node->get_rhs()); + if (expr.kind() == DExpr::Kind::kSub) { + auto* sub_node = static_cast(&expr); + llvm::Value* v_lhs = EmitExpressionImpl(b, *sub_node->get_lhs()); + llvm::Value* v_rhs = EmitExpressionImpl(b, *sub_node->get_rhs()); return b->CreateSub(v_lhs, v_rhs, "sub_dims"); } return nullptr; } +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, const DExpr& expr) { + if (!expr) return nullptr; + StringPrinter printer; + expr->print(&printer); + VLOG(2) 
<< "EmitExpression expr=" << std::move(printer).ToString() + << " kind=" << static_cast(expr.kind()); + llvm::Value* value = EmitExpressionImpl(b, *expr.get()); + return value; +} + llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, llvm::Value* base_ptr, const std::vector& indices, absl::Span dims, - absl::Span expressions, + absl::Span expressions, llvm::Type* elem_type, const llvm::Twine& name) { llvm::Value* total_index = builder->getInt64(0); @@ -930,7 +946,7 @@ llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, // The stride is the product of all dimensions to the right of this index. llvm::Value* stride = builder->getInt64(1); for (size_t j = i; j < dims.size(); ++j) { - if (expressions[j]->is_dynamic()) { + if (expressions[j] && expressions[j]->is_dynamic()) { llvm::Value* expr_value = EmitExpression(builder, expressions[j]); stride = builder->CreateMul(stride, expr_value, "stride.dyn"); @@ -945,7 +961,8 @@ llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, } // Final GEP: result = base + total_index * sizeof(elem_type) - return builder->CreateGEP(elem_type, base_ptr, total_index, name); + llvm::Value* gep = builder->CreateGEP(elem_type, base_ptr, total_index, name); + return gep; } } // namespace llvm_ir diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index c868c81574882f..77bd543c9c0482 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -341,11 +341,11 @@ llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, llvm::Value* base_ptr, const std::vector& indices, absl::Span dims, - absl::Span expressions, + absl::Span expressions, llvm::Type* elem_type, const llvm::Twine& name = ""); -llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr); +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, const DExpr& expr); } // namespace llvm_ir } // namespace xla diff --git 
a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index 4156e37bf4c5dd..20c2b9098cec17 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -200,7 +200,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( bool dynamic = false; for (int i = 0; i < shape_.dimensions_size(); i++) { auto expr = shape_.expressions(i); - if (expr != nullptr && expr->is_dynamic()) { + if (expr && expr->is_dynamic()) { dynamic_dims[i] = xla::llvm_ir::EmitExpression(b_, expr); shape_.set_dynamic_dimension(i, true); dynamic = true; diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 2d0cc3b27cc123..a5a07e6b5ef49d 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -195,7 +195,7 @@ absl::StatusOr InferWindowOutputShape(const Shape& base_shape, std::vector output_dimensions(window.dimensions_size()); std::vector output_is_dynamic(window.dimensions_size()); - std::vector output_expressions(window.dimensions_size()); + std::vector output_expressions(window.dimensions_size()); for (int64_t i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -475,7 +475,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t last_dim = operand_shape.dimensions_size() - 1; std::vector is_dynamic(operand_shape.dimensions_size()); std::vector dimensions(operand_shape.dimensions_size()); - std::vector expressions(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); TF_RET_CHECK(operand_shape.dimensions(last_dim) >= k) << "k=" << k << " is larger than the last dimension of size=" @@ -485,7 +485,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, i == last_dim ? 
false : operand_shape.is_dynamic_dimension(i); dimensions[i] = i == last_dim ? k : operand_shape.dimensions(i); expressions[i] = - i == last_dim ? xla::DynExpr::_(k) : operand_shape.expressions(i); + i == last_dim ? xla::DExpr::Const(k) : operand_shape.expressions(i); } Shape out = @@ -549,11 +549,11 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t rank = arg_shape->dimensions_size(); std::vector inferred_sizes(rank, Shape::kUnboundedSize); std::vector inferred_bounds(rank, Shape::kUnboundedSize); - std::vector inferred_expressions(rank, DynExpr::zero); + std::vector inferred_expressions(rank, xla::DExpr::Const(0)); // Note: for the concatenate dimension, 0 should be the identity element: // Any dim size can keep unchanged when concatenated with 0 inferred_sizes[dimension] = 0; - inferred_expressions[dimension] = DynExpr::zero; + inferred_expressions[dimension] = xla::DExpr::Const(0); for (const Shape* shape : arg_shapes) { for (int dim = 0; dim < rank; ++dim) { @@ -563,27 +563,27 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t leftSize = inferred_sizes[dim]; int64_t rightSize = dimension_size; int64_t leftBound = inferred_bounds[dim]; - xla::DynExpr* leftExpression = inferred_expressions[dim]; + const xla::DExpr& left_expression = inferred_expressions[dim]; int64_t rightBound = shape->is_dynamic_dimension(dim) ? 
dimension_size : Shape::kUnboundedSize; - xla::DynExpr* rightExpression = shape->expressions(dim); - xla::DynExpr* inferred_expression = xla::DynExpr::zero; + const xla::DExpr& right_expression = shape->expressions(dim); + xla::DExpr inferred_expression = xla::DExpr::Const(0); if (dim == dimension) { inferred_dim_and_bound = InferConcatenatedDimAndBound( leftSize, rightSize, leftBound, rightBound); - inferred_expression = *leftExpression + *rightExpression; + inferred_expression = left_expression + right_expression; } else { TF_ASSIGN_OR_RETURN( inferred_dim_and_bound, InferMostSpecificDimAndBound(dim, leftSize, rightSize, leftBound, rightBound)); - inferred_expression = rightExpression; + inferred_expression = right_expression; } inferred_sizes[dim] = inferred_dim_and_bound.dimension; inferred_bounds[dim] = inferred_dim_and_bound.bound; - inferred_expressions[dim] = inferred_expression->s(); + inferred_expressions[dim] = inferred_expression.simplify(); } } @@ -773,7 +773,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, std::vector dimensions(operand_shape.dimensions_size()); std::vector is_dynamic(operand_shape.dimensions_size()); - std::vector expressions(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); if (operand_shape.is_unbounded_dynamic_dimension(i)) { @@ -790,7 +790,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, } is_dynamic[i] = operand_shape.is_dynamic_dimension(i); auto diff = dimensions[i] - operand_shape.dimensions(i); - expressions[i] = (*operand_shape.expressions(i) + diff)->s(); + expressions[i] = (operand_shape.expressions(i) + diff).simplify(); } return ShapeUtil::MakeShape( @@ -941,7 +941,7 @@ void GenerateDotResultDimensions( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, std::vector& dimensions, - std::vector& expressions, + 
std::vector& expressions, std::vector& is_dynamic, std::vector rhs_group_dimensions = {}) { const auto& lhs_batch_dimensions = dimension_numbers.lhs_batch_dimensions(); @@ -1016,7 +1016,7 @@ void GenerateDotResultDimensions( std::vector dimensions; std::vector is_dynamic; - std::vector expressions; + std::vector expressions; GenerateDotResultDimensions(lhs, rhs, dimension_numbers, dimensions, expressions, is_dynamic); @@ -1221,7 +1221,7 @@ void GenerateDotResultDimensions( PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); std::vector dimensions; - std::vector expressions; + std::vector expressions; std::vector is_dynamic; // Add the group dimension to the result shape in case of ragged contracting. if (mode == kContracting) { @@ -1278,7 +1278,7 @@ void GenerateDotResultDimensions( // Build the resulting shape dimensions. std::vector dimensions; std::vector is_dynamic; - std::vector expressions; + std::vector expressions; for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { dimensions.push_back(i != sparsity.dimension() ? operand_shape.dimensions(i) : metadata_dimension_size); @@ -1300,7 +1300,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, // from the lhs/rhs pair in every index. 
std::vector output_dimensions(lhs.dimensions_size()); std::vector output_dimensions_is_dynamic(lhs.dimensions_size()); - std::vector output_dimensions_expressions(lhs.dimensions_size()); + std::vector output_dimensions_expressions(lhs.dimensions_size()); for (int64_t i = 0; i < lhs.dimensions_size(); ++i) { if (lhs.dimensions(i) == 1 || rhs.dimensions(i) == 1) { // For the unbounded case, the operand with 1 should be broadcasted to the @@ -1441,7 +1441,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, dimension_to_match, larger_shape.dimensions_size())); } int64_t small_dimension_size = smaller_shape.dimensions(i); - DynExpr* small_dimension_exp = smaller_shape.expressions(i); + const DExpr& small_dimension_exp = smaller_shape.expressions(i); int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match); bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); bool large_is_dynamic = @@ -1911,12 +1911,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); - DynExpr* expression_feature = - operand_shape.expressions(feature_index); + const DExpr& expression_feature = operand_shape.expressions(feature_index); + std::array feature_expressions = {expression_feature}; Shape output_shape_for_mean_and_var = ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}, - {dynamic_feature}, {expression_feature}); + {dynamic_feature}, feature_expressions); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(offset_shape, 0), feature_count)) { @@ -2199,12 +2199,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); - DynExpr* expression_feature = - operand_shape.expressions(feature_index); + const DExpr& 
expression_feature = operand_shape.expressions(feature_index); + std::array feature_expressions = {expression_feature}; Shape feature_shape = ShapeUtil::MakeShape( operand_shape.element_type(), {feature_count}, {dynamic_feature}, - {expression_feature}); + feature_expressions); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(mean_shape, 0), feature_count)) { @@ -2453,13 +2453,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } std::vector dynamic_dimensions(input_spatial_dims.size()); - std::vector expressions(input_spatial_dims.size()); + std::vector expressions(input_spatial_dims.size()); for (auto it = input_spatial_dims.begin(); it != input_spatial_dims.end(); ++it) { dynamic_dimensions[it - input_spatial_dims.begin()] = IsUnboundedDynamicSize(*it); - expressions[it - input_spatial_dims.begin()] = - DynExpr::_(-70); + expressions[it - input_spatial_dims.begin()] = DExpr::Unknown(70); } Shape base_shape = ShapeUtil::MakeShape( lhs.element_type(), input_spatial_dims, dynamic_dimensions, @@ -2841,7 +2840,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const std::vector dynamic_dimensions(shape.dynamic_dimensions().begin(), shape.dynamic_dimensions().end()); auto exprs = shape.expressions(); - std::vector expressions(exprs.begin(), exprs.end()); + std::vector expressions(exprs.begin(), exprs.end()); return ShapeUtil::MakeShape(shape.element_type(), new_dimensions, dynamic_dimensions, expressions); } @@ -2985,7 +2984,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::vector new_dimensions; std::vector new_is_dynamic; - std::vector new_expressions; + std::vector new_expressions; for (int i = 0; i < arg.dimensions_size(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); @@ -3240,8 +3239,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr 
ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides, - absl::Span start_exprs, - absl::Span limit_exprs) { + absl::Span start_exprs, + absl::Span limit_exprs) { auto error = [&](const std::string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " @@ -3271,7 +3270,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } std::vector sizes; - std::vector expressions; + std::vector expressions; const auto starts_size = starts.size(); sizes.reserve(starts_size); expressions.reserve(starts_size); @@ -3303,13 +3302,15 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } sizes.push_back((limit_index - start_index + stride - 1) / stride); - auto limit_expr = - limit_exprs.empty() ? DynExpr::_(limit_index) : limit_exprs[dimension]; - auto start_expr = - start_exprs.empty() ? DynExpr::_(start_index) : start_exprs[dimension]; + DExpr limit_expr = + limit_exprs.empty() ? DExpr::Const(limit_index) : limit_exprs[dimension]; + DExpr start_expr = + start_exprs.empty() ? 
DExpr::Const(start_index) : start_exprs[dimension]; - auto new_expr = (*(*(*limit_expr - *start_expr) + stride) - 1)->s(); - expressions.push_back((*new_expr/stride)->s()); + auto new_expr = + (limit_expr - start_expr + DExpr::Const(stride) - DExpr::Const(1)) + .simplify(); + expressions.push_back((new_expr / DExpr::Const(stride)).simplify()); } std::vector is_dynamic(arg.dimensions_size()); @@ -3328,7 +3329,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, - absl::Span slice_exprs, bool allow_scalar_indices) { + absl::Span slice_exprs, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); auto number_of_indices = start_index_shapes.size(); // TODO(b/118437727): Remove this path. @@ -3728,7 +3729,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs) { + absl::Span broadcast_exprs) { // This method is used to infer shape for xla::BroadcastInDim. 
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); TF_RET_CHECK(!operand.is_unbounded_dynamic()); @@ -3820,7 +3821,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions) { + absl::Span expressions) { if (new_size_bounds.size() != dims_are_dynamic.size()) { return InvalidArgument( "DynamicReshape has to have the same number of elements in new_sizes " @@ -3851,13 +3852,13 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension, absl::Span expressions) { + int64_t inferred_dimension, absl::Span expressions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), dimensions, expressions); if (expressions.empty() && operand.expressions().size() > 0 && - operand.expressions(0) != nullptr && operand.expressions(0)->is_dynamic()) { + operand.expressions(0)->is_dynamic()) { return InvalidArgument("Expressions is empty but operand is dynamic"); } @@ -4345,7 +4346,7 @@ static absl::Status ValidateGatherDimensionNumbers( std::vector expanded_start_indices_shape; // Also tracks if an output dimension is dynamic. 
std::vector expanded_start_indices_shape_dynamic_dimensions; - std::vector expanded_start_indices_shape_expressions; + std::vector expanded_start_indices_shape_expressions; expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); expanded_start_indices_shape_dynamic_dimensions.reserve( start_indices_shape.dimensions_size()); @@ -4363,7 +4364,7 @@ static absl::Status ValidateGatherDimensionNumbers( gather_dim_numbers.index_vector_dim()) { expanded_start_indices_shape.push_back(1); expanded_start_indices_shape_dynamic_dimensions.push_back(false); - expanded_start_indices_shape_expressions.push_back(DynExpr::one); + expanded_start_indices_shape_expressions.push_back(DExpr::Const(1)); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( @@ -4430,12 +4431,12 @@ static absl::Status ValidateGatherDimensionNumbers( output_dim_bounds.reserve(result_rank); std::vector output_dim_is_dynamic; - std::vector output_expressions; + std::vector output_expressions; output_dim_is_dynamic.reserve(result_rank); for (int64_t i = 0; i < result_rank; i++) { int64_t current_bound; bool dim_dynamic = false; - DynExpr* expression = DynExpr::_(-80); + DExpr expression = DExpr::Unknown(80); bool is_window_index = absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { @@ -4459,7 +4460,7 @@ static absl::Status ValidateGatherDimensionNumbers( dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen); expression = input_shape.expressions(offset_dims_seen); } else { - expression = DynExpr::_(slice_sizes[offset_dims_seen]); + expression = DExpr::Const(slice_sizes[offset_dims_seen]); } current_bound = slice_sizes[offset_dims_seen++]; } else { diff --git a/third_party/xla/xla/service/shape_inference.h b/third_party/xla/xla/service/shape_inference.h index 4f5d0f4b2e3efa..cb0ee9a20f83c6 100644 --- a/third_party/xla/xla/service/shape_inference.h +++ b/third_party/xla/xla/service/shape_inference.h @@ -246,15 +246,15 @@ class ShapeInference { static 
absl::StatusOr InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides, - absl::Span start_exprs = {}, - absl::Span limit_exprs = {}); + absl::Span start_exprs = {}, + absl::Span limit_exprs = {}); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static absl::StatusOr InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, - absl::Span slice_exprs = {}, + absl::Span slice_exprs = {}, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based @@ -288,7 +288,7 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static absl::StatusOr InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes, - absl::Span broadcast_exprs = {}); + absl::Span broadcast_exprs = {}); // Checks whether the given parameters can form a broadcast. Returns the same // output_shape if it's legal. @@ -300,7 +300,7 @@ class ShapeInference { // its operand and the new dimension sizes specified. static absl::StatusOr InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension, absl::Span expressions = {}); + int64_t inferred_dimension, absl::Span expressions); // Infers the shape produced by a dynamic reshape operation from the element // type of its operand and the new dimension sizes specified. The result shape @@ -310,7 +310,7 @@ class ShapeInference { const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic, - absl::Span expressions); + absl::Span expressions); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. 
diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index ea587461b1e277..3e51ffd4de8439 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -121,10 +121,10 @@ XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); auto shape_exprs = blocks_shape.expressions(); - auto last_blocks_exprs = std::vector(ndims); + auto last_blocks_exprs = std::vector(ndims); std::copy(shape_exprs.begin(), shape_exprs.end(), last_blocks_exprs.begin()); - last_blocks_exprs.insert(last_blocks_exprs.end() - 2, DynExpr::one); + last_blocks_exprs.insert(last_blocks_exprs.end() - 2, DExpr::Const(1)); last_blocks = Reshape(last_blocks, last_blocks_dims, last_blocks_exprs); // Concatenate with the other blocks if necessary diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 962984b9225b6a..5e256a6303d6b1 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -41,82 +41,72 @@ limitations under the License. 
namespace xla { -DynExpr* ExprFromProto(const ExpressionProto& proto) { - switch (proto.node_type_case()) { - case ExpressionProto::kConstantValue: - return DynExpr::_(proto.constant_value()); - - case ExpressionProto::kVariableId: - return DynExpr::V(proto.variable_id()); - - case ExpressionProto::kAddNode: { - const auto& add = proto.add_node(); - return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); - } - - case ExpressionProto::kSubNode: { - const auto& sub = proto.sub_node(); - return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); - } - - case ExpressionProto::kMulNode: { - const auto& mul = proto.mul_node(); - return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); - } - - case ExpressionProto::kDivNode: { - const auto& div = proto.div_node(); - return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); - } - - case ExpressionProto::NODE_TYPE_NOT_SET: - default: - return nullptr; - } +const DExpr& Shape::MissingExpression() { + static const DExpr missing = DExpr::Unknown(); + return missing; } -DynExpr* operator*(DynExpr& lhs, DynExpr& rhs) { return new Mul(&lhs, &rhs); } +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs) { + return new Mul(lhs.clone().release(), rhs.clone().release()); +} DynExpr* operator*(int64_t k, DynExpr& rhs) { - return new Mul(DynExpr::_(k), &rhs); + return new Mul(DynExpr::_(k), rhs.clone().release()); +} +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs) { + return new Div(lhs.clone().release(), rhs.clone().release()); } -DynExpr* operator/(DynExpr& lhs, DynExpr& rhs) { return new Div(&lhs, &rhs); } DynExpr* operator/(DynExpr& lhs, int64_t d) { - return new Div(&lhs, DynExpr::_(d)); + return new Div(lhs.clone().release(), DynExpr::_(d)); +} +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs) { + return new Add(lhs.clone().release(), rhs.clone().release()); } -DynExpr* operator+(DynExpr& lhs, DynExpr& rhs) { return new Add(&lhs, &rhs); } DynExpr* operator+(DynExpr& lhs, int64_t d) { - return new Add(&lhs, 
DynExpr::_(d)); + return new Add(lhs.clone().release(), DynExpr::_(d)); +} +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { + return new Sub(lhs.clone().release(), rhs.clone().release()); } -DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } DynExpr* operator-(DynExpr& lhs, int64_t d) { - return new Sub(&lhs, DynExpr::_(d)); + return new Sub(lhs.clone().release(), DynExpr::_(d)); } bool operator==(DynExpr& lhs, DynExpr& rhs) { return DynExpr::equal(&lhs, &rhs); } bool operator==(DynExpr& lhs, int64_t d) { - return DynExpr::equal(&lhs, DynExpr::_(d)); + auto rhs = std::unique_ptr(DynExpr::_(d)); + return DynExpr::equal(&lhs, rhs.get()); } bool operator<(DynExpr& lhs, int64_t d) { return lhs.is_constant() && lhs.get_val() < d; } bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { - auto e1 = expr1->s(); - auto e2 = expr2->s(); + auto e1 = std::unique_ptr(expr1->s()); + auto e2 = std::unique_ptr(expr2->s()); if (e1 == nullptr || e2 == nullptr) return false; - Constant* c1 = dynamic_cast(e1); - Constant* c2 = dynamic_cast(e2); - if (c1 && c2) return c1->get_val() == c2->get_val(); + if (e1->kind() == DExpr::Kind::kConstant && + e2->kind() == DExpr::Kind::kConstant) { + return static_cast(e1.get())->get_val() == + static_cast(e2.get())->get_val(); + } // Var x = Var y <=> x = y - if (Variable* varx = dynamic_cast(e1), - *vary = dynamic_cast(e2); - varx && vary) { - return varx->get_id() == vary->get_id(); + if (e1->kind() == DExpr::Kind::kVariable && + e2->kind() == DExpr::Kind::kVariable) { + return static_cast(e1.get())->get_id() == + static_cast(e2.get())->get_id(); + } + if (e1->kind() == DExpr::Kind::kUnknown && + e2->kind() == DExpr::Kind::kUnknown) { + int lhs_id = static_cast(e1.get())->get_id(); + int rhs_id = static_cast(e2.get())->get_id(); + return lhs_id != 0 && lhs_id == rhs_id; } // a * b = c * d <=> (a = c /\ b = d) \/ (a = d /\ b = c) - if (Mul* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); - ab && cd) { + if (e1->kind() 
== DExpr::Kind::kMul && + e2->kind() == DExpr::Kind::kMul) { + auto* ab = static_cast(e1.get()); + auto* cd = static_cast(e2.get()); auto a = ab->get_lhs(); auto b = ab->get_rhs(); auto c = cd->get_lhs(); @@ -124,8 +114,10 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { return (*a == *c && *b == *d) || (*a == *d && *b == *c); } // a / b = c / d <=> (a = c /\ b = d) - if (Div* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); - ab && cd) { + if (e1->kind() == DExpr::Kind::kDiv && + e2->kind() == DExpr::Kind::kDiv) { + auto* ab = static_cast(e1.get()); + auto* cd = static_cast(e2.get()); auto a = ab->get_lhs(); auto b = ab->get_rhs(); auto c = cd->get_lhs(); @@ -133,8 +125,10 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { return *a == *c && *b == *d; } // a + b = c + d <=> (a = c /\ b = d) \/ (a = d /\ b = c) - if (Add* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); - ab && cd) { + if (e1->kind() == DExpr::Kind::kAdd && + e2->kind() == DExpr::Kind::kAdd) { + auto* ab = static_cast(e1.get()); + auto* cd = static_cast(e2.get()); auto a = ab->get_lhs(); auto b = ab->get_rhs(); auto c = cd->get_lhs(); @@ -142,8 +136,10 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { return (*a == *c && *b == *d) || (*a == *d && *b == *c); } // a - b = c - d <=> (a = c /\ b = d) - if (Sub* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); - ab && cd) { + if (e1->kind() == DExpr::Kind::kSub && + e2->kind() == DExpr::Kind::kSub) { + auto* ab = static_cast(e1.get()); + auto* cd = static_cast(e2.get()); auto* a = ab->get_lhs(); auto* b = ab->get_rhs(); auto* c = cd->get_lhs(); @@ -154,176 +150,250 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { } // Simplification methods -DynExpr* Constant::s() { return this; } +DynExpr* Constant::s() { return clone().release(); } -DynExpr* Variable::s() { return this; } +DynExpr* Variable::s() { return clone().release(); } DynExpr* Mul::s() { - DynExpr* s_lhs = get_lhs()->s(); - DynExpr* s_rhs = get_rhs()->s(); - Constant* l = 
dynamic_cast(s_lhs); - Constant* r = dynamic_cast(s_rhs); + auto s_lhs = std::unique_ptr(get_lhs()->s()); + auto s_rhs = std::unique_ptr(get_rhs()->s()); + if (s_lhs->kind() == DExpr::Kind::kUnknown || + s_rhs->kind() == DExpr::Kind::kUnknown) { + return DExpr::Unknown().release(); + } + Constant* l = s_lhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_lhs.get()) + : nullptr; + Constant* r = s_rhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_rhs.get()) + : nullptr; // constant * constant if (l && r) return DynExpr::_(l->get_val() * r->get_val()); // 0 * X = 0 - if (l && l->get_val() == 0) return DynExpr::zero; + if (l && l->get_val() == 0) return DynExpr::_(0); // 1 * X = X - if (l && l->get_val() == 1) return s_rhs; + if (l && l->get_val() == 1) return s_rhs.release(); // X * 1 = X - if (r && r->get_val() == 1) return s_lhs; + if (r && r->get_val() == 1) return s_lhs.release(); // X * constant = constant * X - if (r && s_lhs->is_dynamic()) return (r->get_val() * *s_lhs)->s(); + if (r && s_lhs->is_dynamic()) { + auto reordered = std::unique_ptr(r->get_val() * *s_lhs); + return reordered->s(); + } // m * (nX) = (m*n) * X - if (Mul* nX = dynamic_cast(s_rhs)) { + if (s_rhs->kind() == DExpr::Kind::kMul) { + auto* nX = static_cast(s_rhs.get()); DynExpr* X = nX->get_rhs(); - Constant* n = dynamic_cast(nX->get_lhs()); + Constant* n = nX->get_lhs()->kind() == DExpr::Kind::kConstant + ? 
static_cast(nX->get_lhs()) + : nullptr; if (l && n) { auto mn = l->get_val() * n->get_val(); - return (mn * *X)->s(); + auto folded = std::unique_ptr(mn * *X); + return folded->s(); } } return (*s_lhs) * (*s_rhs); } DynExpr* Add::s() { - DynExpr* s_lhs = get_lhs()->s(); - DynExpr* s_rhs = get_rhs()->s(); - Constant* l = dynamic_cast(s_lhs); - Constant* r = dynamic_cast(s_rhs); + auto s_lhs = std::unique_ptr(get_lhs()->s()); + auto s_rhs = std::unique_ptr(get_rhs()->s()); + if (s_lhs->kind() == DExpr::Kind::kUnknown || + s_rhs->kind() == DExpr::Kind::kUnknown) { + return DExpr::Unknown().release(); + } + Constant* l = s_lhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_lhs.get()) + : nullptr; + Constant* r = s_rhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_rhs.get()) + : nullptr; // constant + constant if (l && r) return DynExpr::_(l->get_val() + r->get_val()); // 0 + X = X - if (l && l->get_val() == 0) return s_rhs; + if (l && l->get_val() == 0) return s_rhs.release(); // X + 0 = X - if (r && r->get_val() == 0) return s_lhs; + if (r && r->get_val() == 0) return s_lhs.release(); // m + X = X + m - if (l && s_rhs->is_dynamic()) return (*s_rhs + l->get_val())->s(); + if (l && s_rhs->is_dynamic()) { + auto reordered = std::unique_ptr(*s_rhs + l->get_val()); + return reordered->s(); + } // X + X = 2 * X if (*s_lhs == *s_rhs) { - return (2 * (*s_rhs))->s(); + auto doubled = std::unique_ptr(2 * (*s_rhs)); + return doubled->s(); } // nX + X = (n+1) * X - if (Mul* nX = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kMul) { + auto* nX = static_cast(s_lhs.get()); DynExpr* n = nX->get_lhs(); DynExpr* X = nX->get_rhs(); if (*X == *s_rhs) { - return (*(*n + 1) * (*X))->s(); + auto incremented = std::unique_ptr(*n + 1); + auto combined = + std::unique_ptr(*incremented * (*X)); + return combined->s(); } } // X + nX = (n+1) * X - if (Mul* nX = dynamic_cast(s_rhs)) { + if (s_rhs->kind() == DExpr::Kind::kMul) { + auto* nX = static_cast(s_rhs.get()); 
DynExpr* n = nX->get_lhs(); DynExpr* X = nX->get_rhs(); if (*X == *s_lhs) { - return (*(*n + 1) * (*X))->s(); + auto incremented = std::unique_ptr(*n + 1); + auto combined = + std::unique_ptr(*incremented * (*X)); + return combined->s(); } } // mX + nX = (m+n) * X - if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); - mX && nY) { + if (s_lhs->kind() == DExpr::Kind::kMul && + s_rhs->kind() == DExpr::Kind::kMul) { + auto* mX = static_cast(s_lhs.get()); + auto* nY = static_cast(s_rhs.get()); DynExpr* m = mX->get_lhs(); DynExpr* X = mX->get_rhs(); DynExpr* n = nY->get_lhs(); DynExpr* Y = nY->get_rhs(); if (*X == *Y) { - return (*(*m + *n) * (*X))->s(); + auto summed = std::unique_ptr(*m + *n); + auto combined = std::unique_ptr(*summed * (*X)); + return combined->s(); } } // (X + Y) + Z = X + (Y + Z) - if (Add* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kAdd) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*X + *(*Y + *s_rhs))->s(); + auto inner = std::unique_ptr(*Y + *s_rhs); + auto reassoc = std::unique_ptr(*X + *inner); + return reassoc->s(); } // (X - Y) + Z = X - (Y - Z) - if (Sub* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kSub) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*X - *(*Y - *s_rhs))->s(); + auto inner = std::unique_ptr(*Y - *s_rhs); + auto reassoc = std::unique_ptr(*X - *inner); + return reassoc->s(); } return *s_lhs + *s_rhs; } DynExpr* Sub::s() { - if (!get_lhs()){ - LOG(INFO) << "NO LEFT"; + auto s_lhs = std::unique_ptr(get_lhs()->s()); + auto s_rhs = std::unique_ptr(get_rhs()->s()); + if (s_lhs->kind() == DExpr::Kind::kUnknown || + s_rhs->kind() == DExpr::Kind::kUnknown) { + return DExpr::Unknown().release(); } - - if (!get_rhs()){ - LOG(INFO) << "NO RIGHT"; - } - DynExpr* s_lhs = get_lhs()->s(); - DynExpr* s_rhs = get_rhs()->s(); - Constant* l = dynamic_cast(s_lhs); - Constant* 
r = dynamic_cast(s_rhs); + Constant* l = s_lhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_lhs.get()) + : nullptr; + Constant* r = s_rhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_rhs.get()) + : nullptr; // constant - constant if (l && r) return DynExpr::_(l->get_val() - r->get_val()); // X - 0 = X - if (r && r->get_val() == 0) return s_lhs; + if (r && r->get_val() == 0) return s_lhs.release(); // X - X = 0 if (*s_lhs == *s_rhs) { - return DynExpr::zero; + return DynExpr::_(0); } // mX - nX = (m-n) * X - if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); - mX && nY) { + if (s_lhs->kind() == DExpr::Kind::kMul && + s_rhs->kind() == DExpr::Kind::kMul) { + auto* mX = static_cast(s_lhs.get()); + auto* nY = static_cast(s_rhs.get()); DynExpr* m = mX->get_lhs(); DynExpr* X = mX->get_rhs(); DynExpr* n = nY->get_lhs(); DynExpr* Y = nY->get_rhs(); if (*X == *Y) { - return (*(*m - *n) * (*X))->s(); + auto diffed = std::unique_ptr(*m - *n); + auto combined = std::unique_ptr(*diffed * (*X)); + return combined->s(); } } // (X + Y) - X = X + (Y - Z) - if (Add* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kAdd) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*X + *(*Y - *s_rhs))->s(); + auto inner = std::unique_ptr(*Y - *s_rhs); + auto reassoc = std::unique_ptr(*X + *inner); + return reassoc->s(); } // (X - Y) - Z = X - (Y + Z) - if (Sub* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kSub) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*X - *(*Y + *s_rhs))->s(); + auto inner = std::unique_ptr(*Y + *s_rhs); + auto reassoc = std::unique_ptr(*X - *inner); + return reassoc->s(); } return *s_lhs - *s_rhs; } DynExpr* Div::s() { - DynExpr* s_lhs = get_lhs()->s(); - DynExpr* s_rhs = get_rhs()->s(); - Constant* l = dynamic_cast(s_lhs); - Constant* r = dynamic_cast(s_rhs); + auto s_lhs = 
std::unique_ptr(get_lhs()->s()); + auto s_rhs = std::unique_ptr(get_rhs()->s()); + if (s_lhs->kind() == DExpr::Kind::kUnknown || + s_rhs->kind() == DExpr::Kind::kUnknown) { + return DExpr::Unknown().release(); + } + Constant* l = s_lhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_lhs.get()) + : nullptr; + Constant* r = s_rhs->kind() == DExpr::Kind::kConstant + ? static_cast(s_rhs.get()) + : nullptr; // constant / constant if (l && r) return DynExpr::_(l->get_val() / r->get_val()); // X / 1 = X - if (r && r->get_val() == 1) return s_lhs; + if (r && r->get_val() == 1) return s_lhs.release(); // (X + Y) / Z = (X/Z) + (Y/Z) - if (Add* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kAdd) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*((*X) / (*s_rhs)) + *((*Y) / (*s_rhs)))->s(); + auto left = std::unique_ptr(*X / *s_rhs); + auto right = std::unique_ptr(*Y / *s_rhs); + auto distributed = std::unique_ptr(*left + *right); + return distributed->s(); } // (X * Y) / Z = (X/Z) * Y - if (Mul* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kMul) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*(*X / (*s_rhs)) * (*Y))->s(); + auto left = std::unique_ptr(*X / *s_rhs); + auto distributed = std::unique_ptr(*left * (*Y)); + return distributed->s(); } // (X / Y) / Z = X / (Y*Z) - if (Div* XY = dynamic_cast(s_lhs)) { + if (s_lhs->kind() == DExpr::Kind::kDiv) { + auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - return (*X / *(*Y * *s_rhs))->s(); + auto inner = std::unique_ptr(*Y * *s_rhs); + auto reassoc = std::unique_ptr(*X / *inner); + return reassoc->s(); } return *s_lhs / *s_rhs; } std::ostream& operator<<(std::ostream& os, DynExpr* expr) { - ExpressionProto proto; - expr->to_proto(&proto); - os << proto.ShortDebugString(); + StringPrinter printer; + expr->print(&printer); + os 
<< std::move(printer).ToString(); return os; } @@ -419,9 +489,9 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { // UnsafeAddDimension. We expect that the caller will eventually call a // validation routine that will detect the error in case the dimension // value is invalid. - DynExpr* expression = (i < num_expressions) - ? ExprFromProto(shape_proto.expressions(i)) - : DynExpr::_(shape_proto.dimensions(i)); + DExpr expression = + (i < num_expressions) ? DExprFromProto(shape_proto.expressions(i)) + : DExpr::Const(shape_proto.dimensions(i)); shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic, expression); } @@ -459,9 +529,9 @@ ShapeProto Shape::ToProto() const { for (const bool dynamic : state->dynamic_dimensions) { proto.add_is_dynamic_dimension(dynamic); } - for (const DynExpr* e : state->expressions) { + for (const DExpr& e : state->expressions) { ExpressionProto* eproto = proto.add_expressions(); - CHECK(e != nullptr) << "Missing expression in expression list."; + CHECK(e.get() != nullptr) << "Missing expression in expression list."; e->to_proto(eproto); } if (state->layout.has_value()) { @@ -559,15 +629,16 @@ bool Shape::AreAllLeavesIntegers() const { return primitive_util::IsIntegralType(element_type()); } -void Shape::add_dimensions(int64_t value, bool is_dynamic, DynExpr* expr) { +void Shape::add_dimensions(int64_t value, bool is_dynamic, DExpr expr) { if (value < 0) { CHECK(is_dynamic) << "static dimension must have size >= 0 instead of " << value << "."; CHECK_EQ(value, kUnboundedSize) << "dynamic dimension must have size == kUnboundedSize or >= 0."; } - UnsafeAddDimension(value, is_dynamic, - expr != nullptr ? expr : DynExpr::_(value)); + UnsafeAddDimension( + value, is_dynamic, + expr ? 
std::move(expr) : DExpr::Const(value)); } void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { @@ -577,20 +648,23 @@ void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { state.dynamic_dimensions[dimension] = is_dynamic; } -void Shape::set_expression(int dimension, DynExpr* e) { +void Shape::set_expression(int dimension, DExpr e) { auto& state = array_state(); state.expressions[dimension] = - e != nullptr ? e : DynExpr::_(state.dimensions[dimension]); + e ? std::move(e) : DExpr::Const(state.dimensions[dimension]); } -void Shape::set_expressions(std::vector exps) { +void Shape::set_expressions(std::vector exps) { auto& state = array_state(); CHECK_LE(exps.size(), state.dimensions.size()); state.expressions.resize(state.dimensions.size()); for (size_t i = 0; i < state.dimensions.size(); ++i) { - DynExpr* expr = i < exps.size() ? exps[i] : DynExpr::_(state.dimensions[i]); - state.expressions[i] = - expr != nullptr ? expr : DynExpr::_(state.dimensions[i]); + state.expressions[i] = i < exps.size() + ? 
std::move(exps[i]) + : DExpr::Const(state.dimensions[i]); + if (!state.expressions[i]) { + state.expressions[i] = DExpr::Const(state.dimensions[i]); + } } } @@ -602,7 +676,7 @@ void Shape::set_dimensions(int index, int64_t size, CheckDimensionSize(index, size, dynamic); state.dimensions[index] = size; state.dynamic_dimensions[index] = dynamic; - state.expressions[index] = DynExpr::_(size); + state.expressions[index] = DExpr::Const(size); } void Shape::set_dimensions_minor(int index, int64_t size, @@ -624,7 +698,7 @@ void Shape::CheckDimensionSize(int dim_index, int64_t size, bool is_dynamic) { } } -void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DynExpr* exp) { +void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DExpr exp) { auto& state = array_state(); CHECK_EQ(state.dimensions.size(), state.dynamic_dimensions.size()) << "where the shape is " << ToString(); @@ -632,7 +706,7 @@ void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DynExpr* exp) { << "where the shape is " << ToString(); state.dimensions.push_back(value); state.dynamic_dimensions.push_back(is_dynamic); - state.expressions.push_back(exp != nullptr ? exp : DynExpr::_(value)); + state.expressions.push_back(exp ? 
std::move(exp) : DExpr::Const(value)); } bool Shape::is_static() const { diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 1c5197ad838a88..c4adfbef2cd563 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -222,8 +222,8 @@ class Shape { bool has_dynamic_expr() const { if (auto* const state = if_array_state()) { return absl::c_any_of(state->expressions, - [](DynExpr* e) { - return e != nullptr && e->is_dynamic(); + [](const DExpr& e) { + return e && e->is_dynamic(); }); } if (auto* const state = if_tuple_state()) { @@ -234,12 +234,12 @@ class Shape { return false; } - DynExpr* expressions(int dimension) const { - if (dimension < 0) return DynExpr::_(-999); + const DExpr& expressions(int dimension) const { + if (dimension < 0) return MissingExpression(); const auto& exprs = array_state().expressions; const size_t dim = static_cast(dimension); - if (dim >= exprs.size()) return DynExpr::_(-999); - return exprs[dim] != nullptr ? exprs[dim] : DynExpr::_(-999); + if (dim >= exprs.size()) return MissingExpression(); + return exprs[dim] ? exprs[dim] : MissingExpression(); } // Returns true if the given dimension is statically-sized. @@ -256,9 +256,9 @@ class Shape { // - The dimension's size is valid for the given dynamic-ness. void set_dynamic_dimension(int dimension, bool is_dynamic); - void set_expression(int dimension, DynExpr* e); + void set_expression(int dimension, DExpr e); - void set_expressions(std::vector exprs); + void set_expressions(std::vector exprs); // Returns a span to indicate whether each dimension is dynamic. // Precondition: this is an array shape. @@ -266,9 +266,7 @@ class Shape { return array_state().dynamic_dimensions; } - absl::Span expressions() const { - return array_state().expressions; - } + absl::Span expressions() const { return array_state().expressions; } // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. 
@@ -346,7 +344,7 @@ class Shape { // - Either `value` is >= 0, or `is_dynamic` is true and `value` is // kUnboundedSize. void add_dimensions(int64_t value, bool is_dynamic = false, - xla::DynExpr* expr = nullptr); + DExpr expr = DExpr()); // Clears all dimensions (i.e. makes this shape a scalar). // Precondition: this is an array shape. @@ -603,10 +601,27 @@ class Shape { // respective dimension is dynamically sized. absl::InlinedVector dynamic_dimensions; - absl::InlinedVector expressions; + absl::InlinedVector expressions; // The layout of the shape. std::optional layout; + + ArrayState() = default; + ArrayState(const ArrayState& other) + : dimensions(other.dimensions), + dynamic_dimensions(other.dynamic_dimensions), + expressions(other.expressions), + layout(other.layout) {} + ArrayState& operator=(const ArrayState& other) { + if (this == &other) return *this; + dimensions = other.dimensions; + dynamic_dimensions = other.dynamic_dimensions; + expressions = other.expressions; + layout = other.layout; + return *this; + } + ArrayState(ArrayState&&) noexcept = default; + ArrayState& operator=(ArrayState&&) noexcept = default; }; struct TupleState { // The tuple element subshapes. @@ -623,13 +638,13 @@ class Shape { // CHECKs that the dimension size is valid. void CheckDimensionSize(int dim_index, int64_t size, bool is_dynamic); + static const DExpr& MissingExpression(); // Like add_dimensions(), but does not CHECK that the arguments are valid. // Instead, we rely on validation down the road to catch invalid shapes. // This is useful for code that should not crash, such as constructing a // Shape from an unvalidated proto. - void UnsafeAddDimension(int64_t value, bool is_dynamic, - DynExpr* exp = nullptr); + void UnsafeAddDimension(int64_t value, bool is_dynamic, DExpr exp); // Convenience accessors for the state_ variant. 
Each if_*_state() accessor // returns a pointer to the corresponding state struct, or nullptr if the diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index 890eed423df38c..a5713d92fbfc2b 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -17,18 +17,35 @@ limitations under the License. #define XLA_SHAPE_DYNEXPR_H_ #include +#include #include -#include #include +#include +#include +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/types/span.h" #include "xla/printer.h" #include "xla/xla_data.pb.h" namespace xla { +enum class DExprKind { + kUnknown, + kConstant, + kVariable, + kAdd, + kSub, + kMul, + kDiv, +}; + class DynExpr { public: virtual ~DynExpr() = default; + virtual std::unique_ptr clone() const = 0; + virtual DExprKind kind() const = 0; virtual void print(xla::Printer* printer) const = 0; virtual void to_proto(xla::ExpressionProto* proto) const = 0; virtual bool is_constant() const = 0; @@ -50,12 +67,136 @@ class DynExpr { friend std::ostream& operator<<(std::ostream& os, DynExpr* expr); }; +class DExpr { + public: + using Kind = DExprKind; + + DExpr() = default; + explicit DExpr(std::unique_ptr expr) : expr_(std::move(expr)) {} + + DExpr(const DExpr& other) { + if (other.expr_ != nullptr) { + expr_ = other.expr_->clone(); + } + } + DExpr& operator=(const DExpr& other) { + if (this == &other) return *this; + expr_.reset(); + if (other.expr_ != nullptr) { + expr_ = other.expr_->clone(); + } + return *this; + } + + DExpr(DExpr&&) noexcept = default; + DExpr& operator=(DExpr&&) noexcept = default; + + static DExpr Unknown(int id = 0); + static DExpr Adopt(DynExpr* expr) { return DExpr(std::unique_ptr(expr)); } + static DExpr Const(int64_t value) { return Adopt(DynExpr::_(value)); } + static DExpr Var(int var_id) { return Adopt(DynExpr::V(var_id)); } + bool is_unknown() const { + return expr_ != nullptr && expr_->kind() == DExprKind::kUnknown; + } + Kind 
kind() const { + CHECK(expr_ != nullptr) << "Attempted to access empty DExpr"; + return expr_->kind(); + } + + DynExpr* get() const { + CHECK(expr_ != nullptr) << "Attempted to access empty DExpr"; + return expr_.get(); + } + DynExpr& operator*() const { + CHECK(expr_ != nullptr) << "Attempted to dereference empty DExpr"; + return *expr_; + } + DynExpr* operator->() const { + CHECK(expr_ != nullptr) << "Attempted to access empty DExpr"; + return expr_.get(); + } + operator DynExpr*() const { return get(); } + explicit operator bool() const { return expr_ != nullptr && !is_unknown(); } + + std::unique_ptr clone() const { + if (expr_ == nullptr) { + return nullptr; + } + return expr_->clone(); + } + DynExpr* release() { return expr_.release(); } + + DExpr simplify() const { + return expr_ == nullptr ? DExpr() : Adopt(expr_->s()); + } + void to_proto(xla::ExpressionProto* proto) const { + CHECK(expr_ != nullptr) << "Attempted to serialize empty DExpr"; + expr_->to_proto(proto); + } + DExpr substitute(int id, const DExpr& value) const { + return expr_ == nullptr ? 
DExpr() : Adopt(expr_->substitute(id, value.get())); + } + + template + friend H AbslHashValue(H h, const DExpr& expr) { + xla::ExpressionProto proto; + if (expr.expr_ != nullptr) { + expr.expr_->to_proto(&proto); + } + return H::combine(std::move(h), proto.SerializeAsString()); + } + + private: + std::unique_ptr expr_; +}; + +class UnknownExpr : public DynExpr { + int id_; + + public: + explicit UnknownExpr(int id = 0) : id_(id) {} + std::unique_ptr clone() const override { + return std::make_unique(id_); + } + DExprKind kind() const override { return DExprKind::kUnknown; } + void print(xla::Printer* printer) const override { + printer->Append("?"); + if (id_ != 0) { + printer->Append(id_); + } + } + void to_proto(xla::ExpressionProto* proto) const override { + (void)proto; + } + bool is_constant() const override { return true; } + int get_id() const { return id_; } + DynExpr* substitute(int id, DynExpr* v) override { + (void)id; + (void)v; + return clone().release(); + } + std::set get_all_ids() override { return {}; } + std::optional solve(int64_t x) override { + (void)x; + return std::nullopt; + } + DynExpr* s() override { return clone().release(); } +}; + +inline DExpr DExpr::Unknown(int id) { + return DExpr(std::unique_ptr(new xla::UnknownExpr(id))); +} + // constant i class Constant : public DynExpr { int64_t value; public: explicit Constant(int64_t v) : value(v) {} + std::unique_ptr clone() const override { + return std::make_unique(value); + } + DExprKind kind() const override { return DExprKind::kConstant; } void print(xla::Printer* printer) const override { if (value < 0) { printer->Append("("); @@ -70,7 +211,7 @@ class Constant : public DynExpr { } bool is_constant() const override { return true; } int64_t get_val() const override { return value; } - DynExpr* substitute(int id, DynExpr* v) { return this; } + DynExpr* substitute(int id, DynExpr* v) { return clone().release(); } std::set get_all_ids() { return {}; } std::optional solve(int64_t x) { return 
std::nullopt; } DynExpr* s() override; @@ -82,6 +223,10 @@ class Variable : public DynExpr { public: explicit Variable(int identifier) : id(identifier) {} + std::unique_ptr clone() const override { + return std::make_unique(id); + } + DExprKind kind() const override { return DExprKind::kVariable; } void print(xla::Printer* printer) const override { // printer->Append("(Var "); char letter = 'A' + (id - 1); @@ -93,7 +238,9 @@ class Variable : public DynExpr { } bool is_constant() const override { return false; } int get_id() const { return id; } - DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} + DynExpr* substitute(int id, DynExpr* v) { + return get_id() == id ? v->clone().release() : clone().release(); + } std::set get_all_ids() { return {get_id()}; } std::optional solve(int64_t x) { return x; } DynExpr* s() override; @@ -101,11 +248,15 @@ class Variable : public DynExpr { // exp = exp + exp class Add : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; + std::unique_ptr lhs; + std::unique_ptr rhs; public: - Add(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + Add(DynExpr* l, DynExpr* r) : lhs(l), rhs(r) {} + std::unique_ptr clone() const override { + return std::make_unique(lhs->clone().release(), rhs->clone().release()); + } + DExprKind kind() const override { return DExprKind::kAdd; } void print(xla::Printer* printer) const override { printer->Append("("); lhs->print(printer); @@ -114,8 +265,8 @@ class Add : public DynExpr { printer->Append(")"); } - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } + DynExpr* get_lhs() const { return lhs.get(); } + DynExpr* get_rhs() const { return rhs.get(); } void to_proto(xla::ExpressionProto* proto) const override { auto* add_msg = proto->mutable_add_node(); @@ -156,19 +307,20 @@ class Add : public DynExpr { DynExpr* s() override; - ~Add() { - delete lhs; - delete rhs; - } + ~Add() override = default; }; // exp = exp - exp class Sub : public 
DynExpr { - DynExpr* lhs; - DynExpr* rhs; + std::unique_ptr lhs; + std::unique_ptr rhs; public: - Sub(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + Sub(DynExpr* l, DynExpr* r) : lhs(l), rhs(r) {} + std::unique_ptr clone() const override { + return std::make_unique(lhs->clone().release(), rhs->clone().release()); + } + DExprKind kind() const override { return DExprKind::kSub; } void print(xla::Printer* printer) const override { printer->Append("("); lhs->print(printer); @@ -177,8 +329,8 @@ class Sub : public DynExpr { printer->Append(")"); } - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } + DynExpr* get_lhs() const { return lhs.get(); } + DynExpr* get_rhs() const { return rhs.get(); } void to_proto(xla::ExpressionProto* proto) const override { auto* sub_msg = proto->mutable_sub_node(); @@ -219,19 +371,20 @@ class Sub : public DynExpr { DynExpr* s() override; - ~Sub() { - delete lhs; - delete rhs; - } + ~Sub() override = default; }; // exp = exp * exp class Mul : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; + std::unique_ptr lhs; + std::unique_ptr rhs; public: - Mul(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + Mul(DynExpr* l, DynExpr* r) : lhs(l), rhs(r) {} + std::unique_ptr clone() const override { + return std::make_unique(lhs->clone().release(), rhs->clone().release()); + } + DExprKind kind() const override { return DExprKind::kMul; } void print(xla::Printer* printer) const override { printer->Append("("); lhs->print(printer); @@ -240,8 +393,8 @@ class Mul : public DynExpr { printer->Append(")"); } - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } + DynExpr* get_lhs() const { return lhs.get(); } + DynExpr* get_rhs() const { return rhs.get(); } void to_proto(xla::ExpressionProto* proto) const override { auto* mul_msg = proto->mutable_mul_node(); @@ -288,19 +441,20 @@ class Mul : public DynExpr { DynExpr* s() override; - ~Mul() { - delete lhs; - 
delete rhs; - } + ~Mul() override = default; }; // expr / expr class Div : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; + std::unique_ptr lhs; + std::unique_ptr rhs; public: - Div(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + Div(DynExpr* l, DynExpr* r) : lhs(l), rhs(r) {} + std::unique_ptr clone() const override { + return std::make_unique
(lhs->clone().release(), rhs->clone().release()); + } + DExprKind kind() const override { return DExprKind::kDiv; } void print(xla::Printer* printer) const override { printer->Append("("); lhs->print(printer); @@ -309,8 +463,8 @@ class Div : public DynExpr { printer->Append(") )"); } - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } + DynExpr* get_lhs() const { return lhs.get(); } + DynExpr* get_rhs() const { return rhs.get(); } void to_proto(xla::ExpressionProto* proto) const override { auto* div_msg = proto->mutable_div_node(); @@ -354,10 +508,7 @@ class Div : public DynExpr { return std::nullopt; } - ~Div() { - delete lhs; - delete rhs; - } + ~Div() override = default; }; DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); @@ -371,9 +522,66 @@ DynExpr* operator-(DynExpr& lhs, int64_t d); bool operator==(DynExpr& lhs, DynExpr& rhs); bool operator==(DynExpr& lhs, int64_t d); +inline DExpr operator*(const DExpr& lhs, const DExpr& rhs) { + return DExpr::Adopt(*lhs.get() * *rhs.get()); +} +inline DExpr operator*(int64_t lhs, const DExpr& rhs) { + return DExpr::Adopt(lhs * *rhs.get()); +} +inline DExpr operator/(const DExpr& lhs, const DExpr& rhs) { + return DExpr::Adopt(*lhs.get() / *rhs.get()); +} +inline DExpr operator/(const DExpr& lhs, int64_t rhs) { + return DExpr::Adopt(*lhs.get() / rhs); +} +inline DExpr operator+(const DExpr& lhs, const DExpr& rhs) { + return DExpr::Adopt(*lhs.get() + *rhs.get()); +} +inline DExpr operator+(const DExpr& lhs, int64_t rhs) { + return DExpr::Adopt(*lhs.get() + rhs); +} +inline DExpr operator-(const DExpr& lhs, const DExpr& rhs) { + return DExpr::Adopt(*lhs.get() - *rhs.get()); +} +inline DExpr operator-(const DExpr& lhs, int64_t rhs) { + return DExpr::Adopt(*lhs.get() - rhs); +} +inline bool operator==(const DExpr& lhs, const DExpr& rhs) { + return *lhs.get() == *rhs.get(); +} +inline bool operator==(const DExpr& lhs, int64_t rhs) { + return *lhs.get() == rhs; +} + +inline DExpr 
DExprFromProto(const xla::ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DExpr::Const(proto.constant_value()); + case ExpressionProto::kVariableId: + return DExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return DExprFromProto(add.lhs()) + DExprFromProto(add.rhs()); + } + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return DExprFromProto(sub.lhs()) - DExprFromProto(sub.rhs()); + } + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return DExprFromProto(mul.lhs()) * DExprFromProto(mul.rhs()); + } + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return DExprFromProto(div.lhs()) / DExprFromProto(div.rhs()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return DExpr::Unknown(); + } +} + inline DynExpr* DynExpr::_(int64_t val) { - if (val == 0) return DynExpr::zero; - if (val == 1) return DynExpr::one; return new Constant(val); } inline DynExpr* DynExpr::V(int var_id) { return new Variable(var_id); } diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 2a72f66d943dfd..a33a0c31d1a447 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -267,19 +267,19 @@ static std::vector MakeDynamicDimensions( return dynamic_dimensions; } -static std::vector MakeExpressions( +static std::vector MakeExpressions( absl::Span dimensions) { - std::vector expressions; + std::vector expressions; expressions.reserve(dimensions.size()); for (int64_t d : dimensions) { - expressions.push_back(DynExpr::_(d)); + expressions.push_back(DExpr::Const(d)); } return expressions; } /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions) { + absl::Span expressions) { return MakeValidatedShape(element_type, dimensions, expressions).value(); } @@ -290,7 +290,7 
@@ static std::vector MakeExpressions( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions, - absl::Span expressions) { + absl::Span expressions) { return MakeValidatedShape(element_type, dimensions, dynamic_dimensions, expressions) .value(); @@ -310,16 +310,26 @@ static std::vector MakeExpressions( /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions) { + absl::Span expressions) { + if (expressions.empty() && !dimensions.empty()) { + std::vector filled_expressions = MakeExpressions(dimensions); + return MakeValidatedShape(element_type, dimensions, + MakeDynamicDimensions(dimensions), + filled_expressions); + } return MakeValidatedShape( - element_type, dimensions, MakeDynamicDimensions(dimensions), - expressions.empty() ? MakeExpressions(dimensions) : expressions); + element_type, dimensions, MakeDynamicDimensions(dimensions), expressions); } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions, - absl::Span expressions) { + absl::Span expressions) { + std::vector filled_expressions; + if (expressions.empty() && !dimensions.empty()) { + filled_expressions = MakeExpressions(dimensions); + expressions = filled_expressions; + } if (dynamic_dimensions.size() != dimensions.size()) { return InvalidArgument( "dynamic dimensions size %d did not match number of dimensions %d", @@ -348,7 +358,7 @@ static std::vector MakeExpressions( for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; const bool is_dynamic = dynamic_dimensions[i]; - DynExpr* expression = expressions[i]; + const DExpr& expression = expressions[i]; if (!Shape::IsValidDimensionSize(d, is_dynamic)) { return InvalidArgument("Invalid dimension size %d, is_dynamic=%s", d, is_dynamic ? 
"true" : "false"); @@ -430,11 +440,12 @@ static std::vector MakeExpressions( /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions) { - std::vector layout(dimensions.size()); - std::iota(layout.rbegin(), layout.rend(), static_cast(0)); - auto shape = MakeShapeWithDenseLayout(element_type, dimensions, layout); - std::vector exprs(expressions.begin(), expressions.end()); + absl::Span expressions) { + auto shape = MakeShapeWithDenseLayout(element_type, dimensions, + LayoutUtil::MakeDescendingLayout( + dimensions.size()) + .minor_to_major()); + std::vector exprs(expressions.begin(), expressions.end()); shape.set_expressions(exprs); return shape; } @@ -759,8 +770,8 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } else { // Only print constant expression if it is different than the dimension // (i.e. it is wrong!) - DynExpr* expr = shape.expressions(i); - bool is_wrong = expr != nullptr && expr->is_constant() && + const DExpr& expr = shape.expressions(i); + bool is_wrong = expr && expr->is_constant() && expr->get_val() != shape.dimensions(i); printer->Append(shape.dimensions(i)); if (is_wrong) { @@ -773,7 +784,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { expr->print(printer); printer->Append("!>"); } - if (expr != nullptr && expr->is_dynamic()) { + if (expr && expr->is_dynamic()) { printer->Append("<"); expr->print(printer); printer->Append(">"); @@ -926,8 +937,8 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); } -/* static */ DynExpr* ShapeUtil::GetExpression(const Shape& shape, - int64_t dimension_number) { +/* static */ const DExpr& ShapeUtil::GetExpression( + const Shape& shape, int64_t dimension_number) { return shape.expressions(GetDimensionNumber(shape, dimension_number)); } diff --git a/third_party/xla/xla/shape_util.h 
b/third_party/xla/xla/shape_util.h index 7e9ed98719ff4f..c84678fa88c9f9 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -332,7 +332,8 @@ class ShapeUtil { // Extracts the shape's expressions at dimension number // GetDimensionNumber(dimension_number). - static DynExpr* GetExpression(const Shape& shape, int64_t dimension_number); + static const DExpr& GetExpression(const Shape& shape, + int64_t dimension_number); // Resolves a dimension number, supporting negative indexing. // @@ -411,7 +412,7 @@ class ShapeUtil { // dimensions. static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions = {}); // Make a scalar shape with given primitive type. static Shape MakeScalarShape(PrimitiveType element_type); @@ -425,7 +426,7 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions, - absl::Span expressions); + absl::Span expressions); // Constructs a new buffer shape with the given element type, and sequence of // dimensions. static Shape MakeBufferShape(PrimitiveType element_type, @@ -437,12 +438,12 @@ class ShapeUtil { // marked static. static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions = {}); static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions, - absl::Span expressions = {}); + absl::Span expressions = {}); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -486,7 +487,7 @@ class ShapeUtil { // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). 
static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions = {}); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions From c0e3c545cb18edf27121481e88eac16ea6cf941a Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 1 Apr 2026 17:35:05 +0100 Subject: [PATCH 2/5] Normalize missing TensorShape expressions to Unknown --- tensorflow/compiler/tf2xla/kernels/const_op.cc | 2 +- tensorflow/core/common_runtime/constant_folding.cc | 2 +- tensorflow/core/framework/tensor_shape.cc | 9 ++++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 0bb0e0958e18eb..faa16509eb2cc2 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -146,7 +146,7 @@ class ConstOp : public XlaOpKernel { if (has_dynamic) { std::vector dimension_constants; for (int i = 0; i < shape.dims(); ++i) { - if (shape.get_expression(i)->is_dynamic()) { + if (shape.get_expression(i) && shape.get_expression(i)->is_dynamic()) { int32_t dim_val = static_cast(shape.dim_size(i)); xla::XlaOp scalar_const = xla::ConstantR0(b, dim_val); xla::ExpressionProto expr_proto; diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index ceae157b5b7ef4..d8d5814ba79c84 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -454,7 +454,7 @@ bool GetShapeFromArgNode(const Node* node, TensorShapeProto* out_shape) { .ok() && !shapes.empty()) { for (auto expression : TensorShape(shapes[0]).get_expressions()) { - if (expression->is_dynamic()) { + if (expression && expression->is_dynamic()) { *out_shape = shapes[0]; return true; } diff --git 
a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 54013ddb603381..163fd8f9b7b826 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -485,9 +485,9 @@ void TensorShapeRep::set_expression(int d, xla::DExpr expr) { return; } if (expressions_.size() <= static_cast(d)) { - expressions_.resize(d + 1); + expressions_.resize(d + 1, xla::DExpr::Unknown()); } - expressions_[d] = std::move(expr); + expressions_[d] = expr ? std::move(expr) : xla::DExpr::Unknown(); } void TensorShapeRep::AddExpression(xla::DExpr expr) { @@ -495,7 +495,7 @@ void TensorShapeRep::AddExpression(xla::DExpr expr) { return; } CHECK_LT(expressions_.size(), ndims_byte()); - expressions_.push_back(std::move(expr)); + expressions_.push_back(expr ? std::move(expr) : xla::DExpr::Unknown()); } void TensorShapeRep::set_expressions(std::vector exprs) { @@ -503,6 +503,9 @@ void TensorShapeRep::set_expressions(std::vector exprs) { expressions_.clear(); return; } + for (auto& expr : exprs) { + if (!expr) expr = xla::DExpr::Unknown(); + } expressions_ = std::move(exprs); } From 6963256544ce883228329a43afc3202522d4cc2d Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 1 Apr 2026 17:40:47 +0100 Subject: [PATCH 3/5] Avoid TensorShape construction in GetShapeFromArgNode --- .../core/common_runtime/constant_folding.cc | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index d8d5814ba79c84..047343403d536b 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -48,6 +48,28 @@ namespace tensorflow { namespace { +bool IsDynamicExpressionProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kVariableId: + return true; + case ExpressionProto::kAddNode: + return 
IsDynamicExpressionProto(proto.add_node().lhs()) || + IsDynamicExpressionProto(proto.add_node().rhs()); + case ExpressionProto::kSubNode: + return IsDynamicExpressionProto(proto.sub_node().lhs()) || + IsDynamicExpressionProto(proto.sub_node().rhs()); + case ExpressionProto::kMulNode: + return IsDynamicExpressionProto(proto.mul_node().lhs()) || + IsDynamicExpressionProto(proto.mul_node().rhs()); + case ExpressionProto::kDivNode: + return IsDynamicExpressionProto(proto.div_node().lhs()) || + IsDynamicExpressionProto(proto.div_node().rhs()); + case ExpressionProto::kConstantValue: + case ExpressionProto::NODE_TYPE_NOT_SET: + return false; + } +} + const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; @@ -453,8 +475,8 @@ bool GetShapeFromArgNode(const Node* node, TensorShapeProto* out_shape) { if (GetNodeAttr(input_node->def(), "_output_shapes", &shapes) .ok() && !shapes.empty()) { - for (auto expression : TensorShape(shapes[0]).get_expressions()) { - if (expression && expression->is_dynamic()) { + for (const auto& expression : shapes[0].expressions()) { + if (IsDynamicExpressionProto(expression)) { *out_shape = shapes[0]; return true; } From 35912ee84db2f3824cea328f719d76800bf42864 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 1 Apr 2026 17:45:51 +0100 Subject: [PATCH 4/5] Share dynamic TensorShape proto helper --- .../core/common_runtime/constant_folding.cc | 31 +++---------------- .../core/framework/tensor_shape_expr.cc | 31 +++++++++++++++++++ tensorflow/core/framework/tensor_shape_expr.h | 6 ++++ 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 047343403d536b..a00dfce7a68998 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -48,28 +49,6 @@ namespace tensorflow { namespace { -bool IsDynamicExpressionProto(const ExpressionProto& proto) { - switch (proto.node_type_case()) { - case ExpressionProto::kVariableId: - return true; - case ExpressionProto::kAddNode: - return IsDynamicExpressionProto(proto.add_node().lhs()) || - IsDynamicExpressionProto(proto.add_node().rhs()); - case ExpressionProto::kSubNode: - return IsDynamicExpressionProto(proto.sub_node().lhs()) || - IsDynamicExpressionProto(proto.sub_node().rhs()); - case ExpressionProto::kMulNode: - return IsDynamicExpressionProto(proto.mul_node().lhs()) || - IsDynamicExpressionProto(proto.mul_node().rhs()); - case ExpressionProto::kDivNode: - return IsDynamicExpressionProto(proto.div_node().lhs()) || - IsDynamicExpressionProto(proto.div_node().rhs()); - case ExpressionProto::kConstantValue: - case ExpressionProto::NODE_TYPE_NOT_SET: - return false; - } -} - const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; @@ -475,11 +454,9 @@ bool GetShapeFromArgNode(const Node* node, TensorShapeProto* out_shape) { if (GetNodeAttr(input_node->def(), "_output_shapes", &shapes) .ok() && !shapes.empty()) { - for (const auto& expression : shapes[0].expressions()) { - if (IsDynamicExpressionProto(expression)) { - *out_shape = shapes[0]; - return true; - } + if (HasDynamicDimExprs(shapes[0])) { + *out_shape = shapes[0]; + return true; } } } diff --git a/tensorflow/core/framework/tensor_shape_expr.cc b/tensorflow/core/framework/tensor_shape_expr.cc index d1e036f9b1d321..baca4df733d63f 100644 --- 
a/tensorflow/core/framework/tensor_shape_expr.cc +++ b/tensorflow/core/framework/tensor_shape_expr.cc @@ -26,6 +26,37 @@ bool TensorShapeExpressionsEnabled() { return enabled; } +bool IsDynamicDimExpr(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kVariableId: + return true; + case ExpressionProto::kAddNode: + return IsDynamicDimExpr(proto.add_node().lhs()) || + IsDynamicDimExpr(proto.add_node().rhs()); + case ExpressionProto::kSubNode: + return IsDynamicDimExpr(proto.sub_node().lhs()) || + IsDynamicDimExpr(proto.sub_node().rhs()); + case ExpressionProto::kMulNode: + return IsDynamicDimExpr(proto.mul_node().lhs()) || + IsDynamicDimExpr(proto.mul_node().rhs()); + case ExpressionProto::kDivNode: + return IsDynamicDimExpr(proto.div_node().lhs()) || + IsDynamicDimExpr(proto.div_node().rhs()); + case ExpressionProto::kConstantValue: + case ExpressionProto::NODE_TYPE_NOT_SET: + return false; + } +} + +bool HasDynamicDimExprs(const TensorShapeProto& proto) { + for (const auto& expr : proto.expressions()) { + if (IsDynamicDimExpr(expr)) { + return true; + } + } + return false; +} + std::unique_ptr DimExpr::Cons(int64_t val) { return std::make_unique(val); } diff --git a/tensorflow/core/framework/tensor_shape_expr.h b/tensorflow/core/framework/tensor_shape_expr.h index 6dd4cafaf1adf8..1c215fda268dcf 100644 --- a/tensorflow/core/framework/tensor_shape_expr.h +++ b/tensorflow/core/framework/tensor_shape_expr.h @@ -218,6 +218,12 @@ DimExpr* SimplifyExpr(DimExpr* expr, // Shape-expression support follows the `tf_xla_enable_dynamic_sizes` flag. bool TensorShapeExpressionsEnabled(); +// Returns true if the expression proto depends on a symbolic variable. +bool IsDynamicDimExpr(const ExpressionProto& proto); + +// Returns true if any expression attached to the TensorShapeProto is dynamic. 
+bool HasDynamicDimExprs(const TensorShapeProto& proto); + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ From 3b219efa5e25c453a80eeb344c6dfecab1d123ec Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 1 Apr 2026 17:46:19 +0100 Subject: [PATCH 5/5] Comment proto-only dynamic shape check --- tensorflow/core/common_runtime/constant_folding.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index a00dfce7a68998..151413d0f3bca1 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -454,6 +454,10 @@ bool GetShapeFromArgNode(const Node* node, TensorShapeProto* out_shape) { if (GetNodeAttr(input_node->def(), "_output_shapes", &shapes) .ok() && !shapes.empty()) { + // Stay on the proto here instead of rebuilding a TensorShape. These + // inferred shapes may still contain unknown (-1) dimensions, and the + // proto expressions are enough for deciding whether the value is + // dynamically derived. if (HasDynamicDimExprs(shapes[0])) { *out_shape = shapes[0]; return true;