diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 782ac342fc8991..8629635a0536f2 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -691,6 +691,9 @@ absl::Status CompileToLocalExecutable( int64_t old = shp.dim_size(j); old_vars.push_back({i, j, old}); xla::DExpr padded_expr = xla::DExpr::Const(filled_batch); + // TODO: If fractional expressions are allowed to + // survive until padding substitution, validate integrality before + // calling get_val() here instead of assuming simplify() is exact. xla::DExpr subst_expr = e.substitute(1, padded_expr).simplify(); int64_t new_dim = subst_expr->get_val(); if (new_dim >= 0) { diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index ccf97cb24d75f8..5ea00a8a585909 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -447,6 +447,9 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( VLOG(1) << "Current expression is " << expr; if (run_options) { xla::DExpr batch_size = xla::DExpr::Const(run_options->batch_size()); + // TODO: If fractional expressions are allowed to + // survive until runtime substitution, validate integrality before + // calling get_val() here instead of assuming simplify() is exact. xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify(); shape.set_dim(dim, subst_expr->get_val()); } else { @@ -458,6 +461,9 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( ctx->resource_manager(), BatchSizeResourceName, &bsr)); xla::DExpr batch_size = xla::DExpr::Const(bsr->GetBatchSize()); // Just substitute Var(1) for now. + // TODO: If fractional expressions are allowed to + // survive until runtime substitution, validate integrality before + // calling get_val() here instead of assuming simplify() is exact. 
xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify(); shape.set_dim(dim, subst_expr->get_val()); bsr->Unref(); diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 5e256a6303d6b1..bbb46add7d7af2 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include #include @@ -41,6 +43,38 @@ limitations under the License. namespace xla { +namespace { + +Constant* AsConstant(DynExpr* expr) { + return expr != nullptr && expr->kind() == DExpr::Kind::kConstant + ? static_cast(expr) + : nullptr; +} + +void NormalizeFraction(int64_t* numerator, int64_t* denominator) { + CHECK(denominator != nullptr); + CHECK(*denominator != 0); + int64_t divisor = std::gcd(std::llabs(*numerator), std::llabs(*denominator)); + if (divisor > 1) { + *numerator /= divisor; + *denominator /= divisor; + } + if (*denominator < 0) { + *numerator = -*numerator; + *denominator = -*denominator; + } +} + +std::unique_ptr MultiplyByConstant(int64_t factor, DynExpr* expr) { + CHECK(expr != nullptr); + if (factor == 1) { + return expr->clone(); + } + return std::make_unique(DynExpr::_(factor), expr->clone().release()); +} + +} // namespace + const DExpr& Shape::MissingExpression() { static const DExpr missing = DExpr::Unknown(); return missing; @@ -180,6 +214,43 @@ DynExpr* Mul::s() { auto reordered = std::unique_ptr(r->get_val() * *s_lhs); return reordered->s(); } + // (c / d) * X = (c * X) / d, but keep the reduced fraction symbolic. 
+ if (s_lhs->kind() == DExpr::Kind::kDiv && + s_rhs->kind() != DExpr::Kind::kDiv) { + auto* div = static_cast(s_lhs.get()); + Constant* numerator = AsConstant(div->get_lhs()); + Constant* denominator = AsConstant(div->get_rhs()); + if (numerator != nullptr && denominator != nullptr) { + int64_t reduced_numerator = numerator->get_val(); + int64_t reduced_denominator = denominator->get_val(); + NormalizeFraction(&reduced_numerator, &reduced_denominator); + auto scaled_rhs = MultiplyByConstant(reduced_numerator, s_rhs.get()); + if (reduced_denominator == 1) { + return scaled_rhs->s(); + } + auto rewritten = + std::make_unique
(scaled_rhs->s(), DynExpr::_(reduced_denominator)); + return rewritten->s(); + } + } + if (s_rhs->kind() == DExpr::Kind::kDiv && + s_lhs->kind() != DExpr::Kind::kDiv) { + auto* div = static_cast(s_rhs.get()); + Constant* numerator = AsConstant(div->get_lhs()); + Constant* denominator = AsConstant(div->get_rhs()); + if (numerator != nullptr && denominator != nullptr) { + int64_t reduced_numerator = numerator->get_val(); + int64_t reduced_denominator = denominator->get_val(); + NormalizeFraction(&reduced_numerator, &reduced_denominator); + auto scaled_lhs = MultiplyByConstant(reduced_numerator, s_lhs.get()); + if (reduced_denominator == 1) { + return scaled_lhs->s(); + } + auto rewritten = + std::make_unique
(scaled_lhs->s(), DynExpr::_(reduced_denominator)); + return rewritten->s(); + } + } // m * (nX) = (m*n) * X if (s_rhs->kind() == DExpr::Kind::kMul) { auto* nX = static_cast(s_rhs.get()); @@ -349,14 +420,21 @@ DynExpr* Div::s() { s_rhs->kind() == DExpr::Kind::kUnknown) { return DExpr::Unknown().release(); } - Constant* l = s_lhs->kind() == DExpr::Kind::kConstant - ? static_cast(s_lhs.get()) - : nullptr; - Constant* r = s_rhs->kind() == DExpr::Kind::kConstant - ? static_cast(s_rhs.get()) - : nullptr; + Constant* l = AsConstant(s_lhs.get()); + Constant* r = AsConstant(s_rhs.get()); + if (l && l->get_val() == 0) return DynExpr::_(0); // constant / constant - if (l && r) return DynExpr::_(l->get_val() / r->get_val()); + if (l && r) { + int64_t numerator = l->get_val(); + int64_t denominator = r->get_val(); + NormalizeFraction(&numerator, &denominator); + if (denominator == 1) { + return DynExpr::_(numerator); + } + return std::make_unique
(DynExpr::_(numerator), + DynExpr::_(denominator)) + .release(); + } // X / 1 = X if (r && r->get_val() == 1) return s_lhs.release(); // (X + Y) / Z = (X/Z) + (Y/Z) @@ -369,14 +447,46 @@ DynExpr* Div::s() { auto distributed = std::unique_ptr(*left + *right); return distributed->s(); } - // (X * Y) / Z = (X/Z) * Y - if (s_lhs->kind() == DExpr::Kind::kMul) { + // (c * X) / d = (c/g * X) / (d/g), keeping any non-integral division + // symbolic instead of truncating c / d to zero. + if (r && s_lhs->kind() == DExpr::Kind::kMul) { auto* XY = static_cast(s_lhs.get()); DynExpr* X = XY->get_lhs(); DynExpr* Y = XY->get_rhs(); - auto left = std::unique_ptr(*X / *s_rhs); - auto distributed = std::unique_ptr(*left * (*Y)); - return distributed->s(); + if (Constant* c = AsConstant(X)) { + int64_t numerator = c->get_val(); + int64_t denominator = r->get_val(); + NormalizeFraction(&numerator, &denominator); + auto scaled = MultiplyByConstant(numerator, Y); + if (denominator == 1) { + return scaled->s(); + } + auto rewritten = + std::make_unique
(scaled->s(), DynExpr::_(denominator)); + // Re-simplify only when normalization changed the fraction; an + // already-reduced (c * X) / d would re-enter this rule forever. + if (numerator != c->get_val() || denominator != r->get_val()) { + return rewritten->s(); + } + return rewritten.release(); + } + if (Constant* c = AsConstant(Y)) { + int64_t numerator = c->get_val(); + int64_t denominator = r->get_val(); + NormalizeFraction(&numerator, &denominator); + auto scaled = MultiplyByConstant(numerator, X); + if (denominator == 1) { + return scaled->s(); + } + auto rewritten = + std::make_unique
(scaled->s(), DynExpr::_(denominator)); + // Re-simplify only when normalization changed the fraction; an + // already-reduced (X * c) / d would re-enter this rule forever. + if (numerator != c->get_val() || denominator != r->get_val()) { + return rewritten->s(); + } + return rewritten.release(); + } } // (X / Y) / Z = X / (Y*Z) if (s_lhs->kind() == DExpr::Kind::kDiv) { diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index a5713d92fbfc2b..97d8ee5a7fb4c5 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -293,11 +293,11 @@ class Add : public DynExpr { std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; - if (lhs->get_all_ids().size() == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { // (A + c) = x <=> A = x - c => solve A = y with y = x - c return lhs->solve(x - rhs->get_val()); } - if (rhs->get_all_ids().size() == 1) { + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { // (c + A) = x <=> A = x - c => solve A = y with y = x - c return rhs->solve(x - lhs->get_val()); } @@ -357,11 +357,11 @@ class Sub : public DynExpr { std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; - if (lhs->get_all_ids().size() == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { // (A - c) = x <=> A = x + c => solve A = y with y = x + c return lhs->solve(x + rhs->get_val()); } - if (rhs->get_all_ids().size() == 1) { + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { // (c - A) = x <=> A = c - x => solve A = y with y = c - x return rhs->solve(lhs->get_val() - x); } @@ -421,14 +421,14 @@ class Mul : public DynExpr { std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; - if (lhs->get_all_ids().size() == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { // (A * c) = x <=> A = x / c => solve A = y with y = x / c int64_t c = rhs->get_val(); if (c == 0) return x == 0 ?
lhs->solve(0) : std::nullopt; if (x % c != 0) return std::nullopt; return lhs->solve(x / c); } - if (rhs->get_all_ids().size() == 1) { + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { // (c * A) = x <=> A = x / c => solve A = y with y = x / c int64_t c = lhs->get_val(); if (c == 0) return x == 0 ? rhs->solve(0) : std::nullopt; @@ -473,10 +473,15 @@ class Div : public DynExpr { } bool is_constant() const override { - return lhs->is_constant() && rhs->is_constant(); + return lhs->is_constant() && rhs->is_constant() && rhs->get_val() != 0 && + lhs->get_val() % rhs->get_val() == 0; } - int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); } + int64_t get_val() const override { + CHECK(is_constant()) << "Attempted to get integer value of non-integral " + << "division expression"; + return lhs->get_val() / rhs->get_val(); + } DynExpr* substitute(int id, DynExpr* v) { return new Div(lhs->substitute(id, v), rhs->substitute(id, v)); @@ -493,11 +498,11 @@ class Div : public DynExpr { std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; - if (lhs->get_all_ids().size() == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { // (A / c) = x <=> A = x * c => solve A = y with y = x * c return lhs->solve(x * rhs->get_val()); } - if (rhs->get_all_ids().size() == 1) { + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { // (c / A) = x <=> A = c / x => solve A = y with y = c / x int64_t c = lhs->get_val(); if (x == 0) return std::nullopt;