Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,9 @@ absl::Status CompileToLocalExecutable(
int64_t old = shp.dim_size(j);
old_vars.push_back({i, j, old});
xla::DExpr padded_expr = xla::DExpr::Const(filled_batch);
// TODO: If fractional expressions are allowed to
// survive until padding substitution, validate integrality before
// calling get_val() here instead of assuming simplify() is exact.
xla::DExpr subst_expr = e.substitute(1, padded_expr).simplify();
int64_t new_dim = subst_expr->get_val();
if (new_dim >= 0) {
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/jit/xla_launch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(1) << "Current expression is " << expr;
if (run_options) {
xla::DExpr batch_size = xla::DExpr::Const(run_options->batch_size());
// TODO: If fractional expressions are allowed to
// survive until runtime substitution, validate integrality before
// calling get_val() here instead of assuming simplify() is exact.
xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify();
shape.set_dim(dim, subst_expr->get_val());
} else {
Expand All @@ -458,6 +461,9 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
ctx->resource_manager(), BatchSizeResourceName, &bsr));
xla::DExpr batch_size = xla::DExpr::Const(bsr->GetBatchSize());
// Just substitute Var(1) for now.
// TODO: If fractional expressions are allowed to
// survive until runtime substitution, validate integrality before
// calling get_val() here instead of assuming simplify() is exact.
xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify();
shape.set_dim(dim, subst_expr->get_val());
bsr->Unref();
Expand Down
124 changes: 112 additions & 12 deletions third_party/xla/xla/shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
Expand All @@ -41,6 +43,38 @@ limitations under the License.

namespace xla {

namespace {

Constant* AsConstant(DynExpr* expr) {
return expr != nullptr && expr->kind() == DExpr::Kind::kConstant
? static_cast<Constant*>(expr)
: nullptr;
}

void NormalizeFraction(int64_t* numerator, int64_t* denominator) {
CHECK(denominator != nullptr);
CHECK(*denominator != 0);
int64_t divisor = std::gcd(std::llabs(*numerator), std::llabs(*denominator));
if (divisor > 1) {
*numerator /= divisor;
*denominator /= divisor;
}
if (*denominator < 0) {
*numerator = -*numerator;
*denominator = -*denominator;
}
}

std::unique_ptr<DynExpr> MultiplyByConstant(int64_t factor, DynExpr* expr) {
CHECK(expr != nullptr);
if (factor == 1) {
return expr->clone();
}
return std::make_unique<Mul>(DynExpr::_(factor), expr->clone().release());
}

} // namespace

const DExpr& Shape::MissingExpression() {
static const DExpr missing = DExpr::Unknown();
return missing;
Expand Down Expand Up @@ -180,6 +214,43 @@ DynExpr* Mul::s() {
auto reordered = std::unique_ptr<DynExpr>(r->get_val() * *s_lhs);
return reordered->s();
}
// (c / d) * X = (c * X) / d, but keep the reduced fraction symbolic.
if (s_lhs->kind() == DExpr::Kind::kDiv &&
s_rhs->kind() != DExpr::Kind::kDiv) {
auto* div = static_cast<Div*>(s_lhs.get());
Constant* numerator = AsConstant(div->get_lhs());
Constant* denominator = AsConstant(div->get_rhs());
if (numerator != nullptr && denominator != nullptr) {
int64_t reduced_numerator = numerator->get_val();
int64_t reduced_denominator = denominator->get_val();
NormalizeFraction(&reduced_numerator, &reduced_denominator);
auto scaled_rhs = MultiplyByConstant(reduced_numerator, s_rhs.get());
if (reduced_denominator == 1) {
return scaled_rhs->s();
}
auto rewritten =
std::make_unique<Div>(scaled_rhs->s(), DynExpr::_(reduced_denominator));
return rewritten->s();
}
}
if (s_rhs->kind() == DExpr::Kind::kDiv &&
s_lhs->kind() != DExpr::Kind::kDiv) {
auto* div = static_cast<Div*>(s_rhs.get());
Constant* numerator = AsConstant(div->get_lhs());
Constant* denominator = AsConstant(div->get_rhs());
if (numerator != nullptr && denominator != nullptr) {
int64_t reduced_numerator = numerator->get_val();
int64_t reduced_denominator = denominator->get_val();
NormalizeFraction(&reduced_numerator, &reduced_denominator);
auto scaled_lhs = MultiplyByConstant(reduced_numerator, s_lhs.get());
if (reduced_denominator == 1) {
return scaled_lhs->s();
}
auto rewritten =
std::make_unique<Div>(scaled_lhs->s(), DynExpr::_(reduced_denominator));
return rewritten->s();
}
}
// m * (nX) = (m*n) * X
if (s_rhs->kind() == DExpr::Kind::kMul) {
auto* nX = static_cast<Mul*>(s_rhs.get());
Expand Down Expand Up @@ -349,14 +420,21 @@ DynExpr* Div::s() {
s_rhs->kind() == DExpr::Kind::kUnknown) {
return DExpr::Unknown().release();
}
Constant* l = s_lhs->kind() == DExpr::Kind::kConstant
? static_cast<Constant*>(s_lhs.get())
: nullptr;
Constant* r = s_rhs->kind() == DExpr::Kind::kConstant
? static_cast<Constant*>(s_rhs.get())
: nullptr;
Constant* l = AsConstant(s_lhs.get());
Constant* r = AsConstant(s_rhs.get());
if (l && l->get_val() == 0) return DynExpr::_(0);
// constant / constant
if (l && r) return DynExpr::_(l->get_val() / r->get_val());
if (l && r) {
int64_t numerator = l->get_val();
int64_t denominator = r->get_val();
NormalizeFraction(&numerator, &denominator);
if (denominator == 1) {
return DynExpr::_(numerator);
}
return std::make_unique<Div>(DynExpr::_(numerator),
DynExpr::_(denominator))
.release();
}
// X / 1 = X
if (r && r->get_val() == 1) return s_lhs.release();
// (X + Y) / Z = (X/Z) + (Y/Z)
Expand All @@ -369,14 +447,36 @@ DynExpr* Div::s() {
auto distributed = std::unique_ptr<DynExpr>(*left + *right);
return distributed->s();
}
// (X * Y) / Z = (X/Z) * Y
if (s_lhs->kind() == DExpr::Kind::kMul) {
// (c * X) / d = (c/g * X) / (d/g), keeping any non-integral division
// symbolic instead of truncating c / d to zero.
if (r && s_lhs->kind() == DExpr::Kind::kMul) {
auto* XY = static_cast<Mul*>(s_lhs.get());
DynExpr* X = XY->get_lhs();
DynExpr* Y = XY->get_rhs();
auto left = std::unique_ptr<DynExpr>(*X / *s_rhs);
auto distributed = std::unique_ptr<DynExpr>(*left * (*Y));
return distributed->s();
if (Constant* c = AsConstant(X)) {
int64_t numerator = c->get_val();
int64_t denominator = r->get_val();
NormalizeFraction(&numerator, &denominator);
auto scaled = MultiplyByConstant(numerator, Y);
if (denominator == 1) {
return scaled->s();
}
auto rewritten =
std::make_unique<Div>(scaled->s(), DynExpr::_(denominator));
return rewritten->s();
}
if (Constant* c = AsConstant(Y)) {
int64_t numerator = c->get_val();
int64_t denominator = r->get_val();
NormalizeFraction(&numerator, &denominator);
auto scaled = MultiplyByConstant(numerator, X);
if (denominator == 1) {
return scaled->s();
}
auto rewritten =
std::make_unique<Div>(scaled->s(), DynExpr::_(denominator));
return rewritten->s();
}
}
// (X / Y) / Z = X / (Y*Z)
if (s_lhs->kind() == DExpr::Kind::kDiv) {
Expand Down
25 changes: 15 additions & 10 deletions third_party/xla/xla/shape_dynexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,11 @@ class Add : public DynExpr {
std::optional<int64_t> solve(int64_t x) {
// Cannot solve if both lhs and rhs are dynamic...
if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt;
if (lhs->get_all_ids().size() == 1) {
if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) {
// (A + c) = x <=> A = x - c => solve A = y with y = x - c
return lhs->solve(x - rhs->get_val());
}
if (rhs->get_all_ids().size() == 1) {
if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) {
// (c + A) = x <=> A = x - c => solve A = y with y = x - c
return rhs->solve(x - lhs->get_val());
}
Expand Down Expand Up @@ -357,11 +357,11 @@ class Sub : public DynExpr {
std::optional<int64_t> solve(int64_t x) {
// Cannot solve if both lhs and rhs are dynamic...
if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt;
if (lhs->get_all_ids().size() == 1) {
if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) {
// (A - c) = x <=> A = x + c => solve A = y with y = x + c
return lhs->solve(x + rhs->get_val());
}
if (rhs->get_all_ids().size() == 1) {
if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) {
// (c - A) = x <=> A = c - x => solve A = y with y = c - x
return rhs->solve(lhs->get_val() - x);
}
Expand Down Expand Up @@ -421,14 +421,14 @@ class Mul : public DynExpr {
std::optional<int64_t> solve(int64_t x) {
// Cannot solve if both lhs and rhs are dynamic...
if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt;
if (lhs->get_all_ids().size() == 1) {
if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) {
// (A * c) = x <=> A = x / c => solve A = y with y = x / c
int64_t c = rhs->get_val();
if (c == 0) return x == 0 ? lhs->solve(0) : std::nullopt;
if (x % c != 0) return std::nullopt;
return lhs->solve(x / c);
}
if (rhs->get_all_ids().size() == 1) {
if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) {
// (c * A) = x <=> A = x / c => solve A = y with y = x / c
int64_t c = lhs->get_val();
if (c == 0) return x == 0 ? rhs->solve(0) : std::nullopt;
Expand Down Expand Up @@ -473,10 +473,15 @@ class Div : public DynExpr {
}

bool is_constant() const override {
return lhs->is_constant() && rhs->is_constant();
return lhs->is_constant() && rhs->is_constant() && rhs->get_val() != 0 &&
lhs->get_val() % rhs->get_val() == 0;
}

int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); }
int64_t get_val() const override {
CHECK(is_constant()) << "Attempted to get integer value of non-integral "
<< "division expression";
return lhs->get_val() / rhs->get_val();
}

DynExpr* substitute(int id, DynExpr* v) {
return new Div(lhs->substitute(id, v), rhs->substitute(id, v));
Expand All @@ -493,11 +498,11 @@ class Div : public DynExpr {
std::optional<int64_t> solve(int64_t x) {
// Cannot solve if both lhs and rhs are dynamic...
if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt;
if (lhs->get_all_ids().size() == 1) {
if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) {
// (A / c) = x <=> A = x * c => solve A = y with y = x * c
return lhs->solve(x * rhs->get_val());
}
if (rhs->get_all_ids().size() == 1) {
if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) {
// (c / A) = x <=> A = c / x => solve A = y with y = c / x
int64_t c = lhs->get_val();
if (x == 0) return std::nullopt;
Expand Down