Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tensorflow/compiler/tf2xla/ops/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "xla/service/shape_inference.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -1144,12 +1145,15 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle,
std::vector<int64_t> dims;
std::vector<bool> dynamic_dims;
std::vector<xla::DExpr> expressions;
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) {
bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i));
int dynamic_multiplier = c->DynamicRatio(c->Dim(shape_handle, i));
dynamic_dims.push_back(is_dynamic);
expressions.push_back(xla::DExpr::Const(dynamic_multiplier) *
xla::DExpr::Var(1));
if (flags->tf_xla_enable_dynamic_sizes) {
expressions.push_back(xla::DExpr::Const(dynamic_multiplier) *
xla::DExpr::Var(1));
}
dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize
: c->Value(c->Dim(shape_handle, i)));
}
Expand All @@ -1158,7 +1162,9 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle,
xla::PrimitiveType::S64, dims,
absl::InlinedVector<bool, 4>(dynamic_dims.begin(), dynamic_dims.end()));

sh.set_expressions(expressions);
if (flags->tf_xla_enable_dynamic_sizes) {
sh.set_expressions(expressions);
}
return sh;
}

Expand Down
38 changes: 29 additions & 9 deletions tensorflow/compiler/tf2xla/shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "xla/layout_util.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -100,9 +101,12 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape,
for (int i = 0; i < shape.dimensions().size(); ++i) {
TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i)));
}
std::vector<xla::DExpr> dexprs(shape.expressions().begin(),
shape.expressions().end());
tensor_shape->set_expressions(std::move(dexprs));
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_enable_dynamic_sizes) {
std::vector<xla::DExpr> dexprs(shape.expressions().begin(),
shape.expressions().end());
tensor_shape->set_expressions(std::move(dexprs));
}
return absl::OkStatus();
}

Expand Down Expand Up @@ -170,22 +174,30 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
}
int rank = tensor_shape.dims();
std::vector<int64_t> dimensions(rank);
std::vector<xla::DExpr> expressions(rank);
std::vector<int64_t> layout(rank);
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
std::vector<xla::DExpr> expressions;
if (flags->tf_xla_enable_dynamic_sizes) {
expressions.resize(rank);
}
for (int d = 0; d < rank; ++d) {
dimensions[d] = tensor_shape.dim_size(d);
if (dimensions[d] < 0) {
LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA "
"shape; returning unknown sentinel value";
return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0});
}
expressions[d] = tensor_shape.get_filled_expression(d);
if (flags->tf_xla_enable_dynamic_sizes) {
expressions[d] = tensor_shape.get_filled_expression(d);
}
}
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);
xla::Shape result =
xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
result.set_expressions(expressions);
if (flags->tf_xla_enable_dynamic_sizes) {
result.set_expressions(expressions);
}
return result;
}

Expand Down Expand Up @@ -213,19 +225,27 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
int rank = tensor_shape.dims();
std::vector<int64_t> dimensions(rank);
std::vector<int64_t> layout(rank);
std::vector<xla::DExpr> expressions(rank);
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
std::vector<xla::DExpr> expressions;
if (flags->tf_xla_enable_dynamic_sizes) {
expressions.resize(rank);
}

for (int d = 0; d < rank; ++d) {
dimensions[d] = tensor_shape.dim_size(d);
expressions[d] = tensor_shape.get_filled_expression(d);
if (flags->tf_xla_enable_dynamic_sizes) {
expressions[d] = tensor_shape.get_filled_expression(d);
}
}

// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);

auto shape =
xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
shape.set_expressions(expressions);
if (flags->tf_xla_enable_dynamic_sizes) {
shape.set_expressions(expressions);
}
return shape;
}

Expand Down
30 changes: 23 additions & 7 deletions tensorflow/core/framework/tensor_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ xla::DExpr DExprFromProto(const ExpressionProto& proto) {
}
case ExpressionProto::NODE_TYPE_NOT_SET:
default:
return xla::DExpr::Unknown();
return xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
}
}

Expand Down Expand Up @@ -485,17 +485,21 @@ void TensorShapeRep::set_expression(int d, xla::DExpr expr) {
return;
}
if (expressions_.size() <= static_cast<size_t>(d)) {
expressions_.resize(d + 1, xla::DExpr::Unknown());
expressions_.resize(d + 1,
xla::DExpr::Unknown(xla::kMissingExpressionSentinel));
}
expressions_[d] = expr ? std::move(expr) : xla::DExpr::Unknown();
expressions_[d] = expr ? std::move(expr)
: xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
}

void TensorShapeRep::AddExpression(xla::DExpr expr) {
if (!kTensorShapeExpressionsEnabled) {
return;
}
CHECK_LT(expressions_.size(), ndims_byte());
expressions_.push_back(expr ? std::move(expr) : xla::DExpr::Unknown());
expressions_.push_back(expr ? std::move(expr)
: xla::DExpr::Unknown(
xla::kMissingExpressionSentinel));
}

void TensorShapeRep::set_expressions(std::vector<xla::DExpr> exprs) {
Expand All @@ -504,7 +508,7 @@ void TensorShapeRep::set_expressions(std::vector<xla::DExpr> exprs) {
return;
}
for (auto& expr : exprs) {
if (!expr) expr = xla::DExpr::Unknown();
if (!expr) expr = xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
}
expressions_ = std::move(exprs);
}
Expand Down Expand Up @@ -723,7 +727,13 @@ template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
CHECK_GE(d, 0);
CHECK_LT(d, dims());
if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size));
// After DExpr migration, missing slots may be normalized to Unknown().
// Preserve those placeholders here instead of materializing them into a
// concrete constant just because the dimension size changed.
if (get_expressions().size() > d &&
get_expression(d).kind() != xla::DExpr::Kind::kUnknown) {
set_expression(d, xla::DExpr::Const(size));
}
if (!kIsPartial) {
CHECK_GE(size, 0);
}
Expand Down Expand Up @@ -785,7 +795,13 @@ absl::Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
}
}

if (get_expressions().size() > d) set_expression(d, xla::DExpr::Const(size));
// After DExpr migration, missing slots may be normalized to Unknown().
// Preserve those placeholders here instead of materializing them into a
// concrete constant just because the dimension size changed.
if (get_expressions().size() > d &&
get_expression(d).kind() != xla::DExpr::Kind::kUnknown) {
set_expression(d, xla::DExpr::Const(size));
}
return RecomputeNumElements();
}

Expand Down
9 changes: 6 additions & 3 deletions tensorflow/core/framework/tensor_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class TensorShapeRep {
// Return the multiplier for a specific dynamic dimension.
// -1 if the dimension is not dynamic.
const xla::DExpr& get_expression(int64_t dimension) const {
static const xla::DExpr kMissingExpression = xla::DExpr::Unknown();
static const xla::DExpr kMissingExpression =
xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
if (dimension < 0) return kMissingExpression;
const size_t dim = static_cast<size_t>(dimension);
if (dim >= expressions_.size()) {
Expand All @@ -113,14 +114,16 @@ class TensorShapeRep {
}

xla::DExpr get_filled_expression(int64_t dimension) const {
if (dimension < 0) return xla::DExpr::Unknown();
if (dimension < 0) {
return xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
}
const size_t dim = static_cast<size_t>(dimension);
if (dim < expressions_.size() && expressions_[dim]) {
return expressions_[dim];
}

if (ndims_byte() == kUnknownRank || dim >= ndims_byte()) {
return xla::DExpr::Unknown();
return xla::DExpr::Unknown(xla::kMissingExpressionSentinel);
}

return constant_expression_for_dim(dim);
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/core/grappler/costs/graph_properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2040,9 +2040,9 @@ class SymbolicShapeRefiner {
}

bool changed = false;
const bool force_fresh_unknown_dim = node->op() == "Reshape";
std::vector<DimensionHandle> dims;
dims.reserve(ic->Rank(s));
const bool is_reshape = node->op() == "Reshape";
for (int d = 0; d < ic->Rank(s); ++d) {
DimensionHandle dim = ic->Dim(s, d);
const int64_t v = ic->Value(dim);
Expand All @@ -2053,7 +2053,10 @@ class SymbolicShapeRefiner {
}
// If already tagged with expr, keep it.
auto* dim_expr = ic->GetDimExpr(dim);
if (dim_expr != nullptr && !force_fresh_unknown_dim) {
const bool refresh_reshape_expr =
is_reshape && dim_expr != nullptr &&
dim_expr->kind() != DimExpr::Kind::kVariable;
if (dim_expr != nullptr && !refresh_reshape_expr) {
dims.push_back(dim);
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ limitations under the License.
namespace xla {

const DExpr& Shape::MissingExpression() {
static const DExpr missing = DExpr::Unknown();
static const DExpr missing = DExpr::Unknown(kMissingExpressionSentinel);
return missing;
}

Expand Down
10 changes: 9 additions & 1 deletion third_party/xla/xla/shape_dynexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ limitations under the License.

namespace xla {

// Reserved sentinel for "missing expression". Keep this outside the normal
// expression id space so it cannot be confused with a real UnknownExpr id.
inline constexpr int kMissingExpressionSentinel = -1000001;

enum class DExprKind {
kUnknown,
kConstant,
Expand Down Expand Up @@ -160,6 +164,10 @@ class UnknownExpr : public DynExpr {
}
DExprKind kind() const override { return DExprKind::kUnknown; }
void print(xla::Printer* printer) const override {
if (id_ == kMissingExpressionSentinel) {
printer->Append("_");
return;
}
printer->Append("?");
if (id_ != 0) {
printer->Append(id_);
Expand Down Expand Up @@ -577,7 +585,7 @@ inline DExpr DExprFromProto(const xla::ExpressionProto& proto) {
}
case ExpressionProto::NODE_TYPE_NOT_SET:
default:
return DExpr::Unknown();
return DExpr::Unknown(kMissingExpressionSentinel);
}
}

Expand Down