diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f9ee621bbe5a6f..d8a1387ee76a8c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" @@ -1144,12 +1145,15 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, std::vector<int64_t> dims; std::vector<bool> dynamic_dims; std::vector<xla::DExpr> expressions; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) { bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i)); int dynamic_multiplier = c->DynamicRatio(c->Dim(shape_handle, i)); dynamic_dims.push_back(is_dynamic); - expressions.push_back(xla::DExpr::Const(dynamic_multiplier) * - xla::DExpr::Var(1)); + if (flags->tf_xla_enable_dynamic_sizes) { + expressions.push_back(xla::DExpr::Const(dynamic_multiplier) * + xla::DExpr::Var(1)); + } dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize : c->Value(c->Dim(shape_handle, i))); } @@ -1158,7 +1162,9 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, xla::PrimitiveType::S64, dims, absl::InlinedVector<bool, 4>(dynamic_dims.begin(), dynamic_dims.end())); - sh.set_expressions(expressions); + if (flags->tf_xla_enable_dynamic_sizes) { + sh.set_expressions(expressions); + } return sh; } diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 7e04acf7ecb503..6aaa0419966e37 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "xla/layout_util.h" #include "xla/shape_util.h" @@ -100,9 +101,12 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } - std::vector<xla::DExpr> dexprs(shape.expressions().begin(), - shape.expressions().end()); - tensor_shape->set_expressions(std::move(dexprs)); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + std::vector<xla::DExpr> dexprs(shape.expressions().begin(), + shape.expressions().end()); + tensor_shape->set_expressions(std::move(dexprs)); + } return absl::OkStatus(); } @@ -170,8 +174,12 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, } int rank = tensor_shape.dims(); std::vector<int64_t> dimensions(rank); - std::vector<xla::DExpr> expressions(rank); std::vector<int64_t> layout(rank); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + std::vector<xla::DExpr> expressions; + if (flags->tf_xla_enable_dynamic_sizes) { + expressions.resize(rank); + } for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); if (dimensions[d] < 0) { @@ -179,13 +187,17 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, "shape; returning unknown sentinel value"; return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } - expressions[d] = tensor_shape.get_filled_expression(d); + if (flags->tf_xla_enable_dynamic_sizes) { + expressions[d] = tensor_shape.get_filled_expression(d); + } } // XLA uses minor-to-major; Tensorflow uses major-to-minor. 
std::iota(layout.rbegin(), layout.rend(), 0); xla::Shape result = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - result.set_expressions(expressions); + if (flags->tf_xla_enable_dynamic_sizes) { + result.set_expressions(expressions); + } return result; } @@ -213,11 +225,17 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, int rank = tensor_shape.dims(); std::vector<int64_t> dimensions(rank); std::vector<int64_t> layout(rank); - std::vector<xla::DExpr> expressions(rank); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + std::vector<xla::DExpr> expressions; + if (flags->tf_xla_enable_dynamic_sizes) { + expressions.resize(rank); + } for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); - expressions[d] = tensor_shape.get_filled_expression(d); + if (flags->tf_xla_enable_dynamic_sizes) { + expressions[d] = tensor_shape.get_filled_expression(d); + } } // XLA uses minor-to-major; Tensorflow uses major-to-minor. @@ -225,7 +243,9 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, auto shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - shape.set_expressions(expressions); + if (flags->tf_xla_enable_dynamic_sizes) { + shape.set_expressions(expressions); + } return shape; } diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 163fd8f9b7b826..095715a13e1a5d 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -56,7 +56,7 @@ xla::DExpr DExprFromProto(const ExpressionProto& proto) { } case ExpressionProto::NODE_TYPE_NOT_SET: default: - return xla::DExpr::Unknown(); + return xla::DExpr::Unknown(xla::kMissingExpressionSentinel); } } @@ -485,9 +485,11 @@ void TensorShapeRep::set_expression(int d, xla::DExpr expr) { return; } if (expressions_.size() <= static_cast<size_t>(d)) { - expressions_.resize(d + 1, xla::DExpr::Unknown()); + expressions_.resize(d + 1, + xla::DExpr::Unknown(xla::kMissingExpressionSentinel)); } 
- expressions_[d] = expr ? std::move(expr) : xla::DExpr::Unknown(); + expressions_[d] = expr ? std::move(expr) + : xla::DExpr::Unknown(xla::kMissingExpressionSentinel); } void TensorShapeRep::AddExpression(xla::DExpr expr) { @@ -495,7 +497,9 @@ void TensorShapeRep::AddExpression(xla::DExpr expr) { return; } CHECK_LT(expressions_.size(), ndims_byte()); - expressions_.push_back(expr ? std::move(expr) : xla::DExpr::Unknown()); + expressions_.push_back(expr ? std::move(expr) : xla::DExpr::Unknown( + xla::kMissingExpressionSentinel)); } void TensorShapeRep::set_expressions(std::vector<xla::DExpr> exprs) { @@ -504,7 +508,7 @@ void TensorShapeRep::set_expressions(std::vector<xla::DExpr> exprs) { return; } for (auto& expr : exprs) { - if (!expr) expr = xla::DExpr::Unknown(); + if (!expr) expr = xla::DExpr::Unknown(xla::kMissingExpressionSentinel); } expressions_ = std::move(exprs); } @@ -723,7 +727,13 @@ template <class Shape> void TensorShapeBase<Shape>::set_dim(int d, int64_t size) { CHECK_GE(d, 0); CHECK_LT(d, dims()); - if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size)); + // After DExpr migration, missing slots may be normalized to Unknown(). + // Preserve those placeholders here instead of materializing them into a + // concrete constant just because the dimension size changed. + if (get_expressions().size() > d && + get_expression(d).kind() != xla::DExpr::Kind::kUnknown) { + set_expression(d, xla::DExpr::Const(size)); + } if (!kIsPartial) { CHECK_GE(size, 0); } @@ -785,7 +795,13 @@ absl::Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) { } } - if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size)); + // After DExpr migration, missing slots may be normalized to Unknown(). + // Preserve those placeholders here instead of materializing them into a + // concrete constant just because the dimension size changed. 
+ if (get_expressions().size() > d && + get_expression(d).kind() != xla::DExpr::Kind::kUnknown) { + set_expression(d, xla::DExpr::Const(size)); + } return RecomputeNumElements(); } diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 5507cfac636fa0..0ecab26ea045ae 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -103,7 +103,8 @@ class TensorShapeRep { // Return the multiplier for a specific dynamic dimension. // -1 if the dimension is not dynamic. const xla::DExpr& get_expression(int64_t dimension) const { - static const xla::DExpr kMissingExpression = xla::DExpr::Unknown(); + static const xla::DExpr kMissingExpression = + xla::DExpr::Unknown(xla::kMissingExpressionSentinel); if (dimension < 0) return kMissingExpression; const size_t dim = static_cast<size_t>(dimension); if (dim >= expressions_.size()) { } } xla::DExpr get_filled_expression(int64_t dimension) const { - if (dimension < 0) return xla::DExpr::Unknown(); + if (dimension < 0) { + return xla::DExpr::Unknown(xla::kMissingExpressionSentinel); + } const size_t dim = static_cast<size_t>(dimension); if (dim < expressions_.size() && expressions_[dim]) { return expressions_[dim]; } if (ndims_byte() == kUnknownRank || dim >= ndims_byte()) { - return xla::DExpr::Unknown(); + return xla::DExpr::Unknown(xla::kMissingExpressionSentinel); } return constant_expression_for_dim(dim); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 181567cc0fc888..4d880061828909 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -2040,9 +2040,9 @@ class SymbolicShapeRefiner { } bool changed = false; - const bool force_fresh_unknown_dim = node->op() == "Reshape"; std::vector<DimensionHandle> dims; dims.reserve(ic->Rank(s)); + const bool is_reshape = node->op() == "Reshape"; for 
(int d = 0; d < ic->Rank(s); ++d) { DimensionHandle dim = ic->Dim(s, d); const int64_t v = ic->Value(dim); @@ -2053,7 +2053,10 @@ class SymbolicShapeRefiner { } // If already tagged with expr, keep it. auto* dim_expr = ic->GetDimExpr(dim); - if (dim_expr != nullptr && !force_fresh_unknown_dim) { + const bool refresh_reshape_expr = + is_reshape && dim_expr != nullptr && + dim_expr->kind() != DimExpr::Kind::kVariable; + if (dim_expr != nullptr && !refresh_reshape_expr) { dims.push_back(dim); continue; } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 5e256a6303d6b1..22ea27213d8d8c 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -42,7 +42,7 @@ limitations under the License. namespace xla { const DExpr& Shape::MissingExpression() { - static const DExpr missing = DExpr::Unknown(); + static const DExpr missing = DExpr::Unknown(kMissingExpressionSentinel); return missing; } diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index a5713d92fbfc2b..6722a77d50971b 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -31,6 +31,10 @@ limitations under the License. namespace xla { +// Reserved sentinel for "missing expression". Keep this outside the normal +// expression id space so it cannot be confused with a real UnknownExpr id. 
+inline constexpr int kMissingExpressionSentinel = -1000001; + enum class DExprKind { kUnknown, kConstant, @@ -160,6 +164,10 @@ class UnknownExpr : public DynExpr { } DExprKind kind() const override { return DExprKind::kUnknown; } void print(xla::Printer* printer) const override { + if (id_ == kMissingExpressionSentinel) { + printer->Append("_"); + return; + } printer->Append("?"); if (id_ != 0) { printer->Append(id_); @@ -577,7 +585,7 @@ inline DExpr DExprFromProto(const xla::ExpressionProto& proto) { } case ExpressionProto::NODE_TYPE_NOT_SET: default: - return DExpr::Unknown(); + return DExpr::Unknown(kMissingExpressionSentinel); } }