25 changes: 0 additions & 25 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -494,31 +494,6 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
 
 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
 
-void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) {
-  auto e = expr->s();
-  if (xla::Constant* c = dynamic_cast<xla::Constant*>(e)) {
-    proto->set_constant_value(c->get_val());
-  } else if (xla::Variable* v = dynamic_cast<xla::Variable*>(e)) {
-    proto->set_variable_id(v->get_id());
-  } else if (xla::Add* a = dynamic_cast<xla::Add*>(e)) {
-    auto* add_msg = proto->mutable_add_node();
-    ExprToProto(a->get_lhs(), add_msg->mutable_lhs());
-    ExprToProto(a->get_rhs(), add_msg->mutable_rhs());
-  } else if (xla::Mul* m = dynamic_cast<xla::Mul*>(e)) {
-    auto* mul_msg = proto->mutable_mul_node();
-    ExprToProto(m->get_lhs(), mul_msg->mutable_lhs());
-    ExprToProto(m->get_rhs(), mul_msg->mutable_rhs());
-  } else if (xla::Sub* s = dynamic_cast<xla::Sub*>(e)) {
-    auto* sub_msg = proto->mutable_sub_node();
-    ExprToProto(s->get_lhs(), sub_msg->mutable_lhs());
-    ExprToProto(s->get_rhs(), sub_msg->mutable_rhs());
-  } else if (xla::Div* d = dynamic_cast<xla::Div*>(e)) {
-    auto* div_msg = proto->mutable_div_node();
-    ExprToProto(d->get_lhs(), div_msg->mutable_lhs());
-    ExprToProto(d->get_rhs(), div_msg->mutable_rhs());
-  }
-}
-
 absl::Status Encapsulator::Subgraph::RecordArg(
     const Edge* edge,
     const absl::flat_hash_map<const Node*, Node*>& node_images,
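Note on the deletion above: ExprToProto was the pointer-based serializer for dimension expressions, and the remaining call sites in this PR move to the value-type xla::DExpr, so the helper appears to be dead. For reference, the ExpressionProto it emitted is a small recursive tree; a minimal sketch building the expression 2 * v1 with the same mutators (message shape inferred from the deleted code, so treat the field names as assumptions):

// Hypothetical fragment; mirrors the mutators used by the deleted
// ExprToProto rather than a documented API.
ExpressionProto proto;
auto* mul = proto.mutable_mul_node();
mul->mutable_lhs()->set_constant_value(2);  // lhs: constant 2
mul->mutable_rhs()->set_variable_id(1);     // rhs: dimension variable v1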
49 changes: 24 additions & 25 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -419,34 +419,33 @@ std::unique_ptr<DimExpr> ExprFromProto(const ExpressionProto& proto) {
   }
 }
 
-static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) {
+static xla::DExpr DimExprToDExpr(const DimExpr* e) {
   switch (e->kind()) {
     case DimExpr::Kind::kConstant: {
       auto* ac = static_cast<const Constant*>(e);
-      return xla::DynExpr::_(ac->value());
+      return xla::DExpr::Const(ac->value());
     }
     case DimExpr::Kind::kVariable: {
       auto* av = static_cast<const Variable*>(e);
-      return xla::DynExpr::V(1);
+      return xla::DExpr::Var(1);
     }
     case DimExpr::Kind::kAdd: {
       auto* ee = static_cast<const ExprAdd*>(e);
-      return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) + DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kSub: {
       auto* ee = static_cast<const ExprSub*>(e);
-      return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) - DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kMul: {
       auto* ee = static_cast<const ExprMul*>(e);
-      return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) * DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kDiv: {
       auto* ee = static_cast<const ExprDiv*>(e);
-      return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) / DimExprToDExpr(ee->rhs());
     }
   }
-  return nullptr;
+  return xla::DExpr::Unknown();
 }
-
 
@@ -511,12 +510,12 @@ absl::Status CompileToLocalExecutable(
   int64_t dynamic_dim_value = 0;
   XlaBatchMatcher* xla_batch_matcher =
       xla_device_compiler->xla_batch_matcher();
-  xla::DynExpr* dynamic_dim_expr = nullptr;
-  auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DynExpr* expr) {
+  std::optional<xla::DExpr> dynamic_dim_expr;
+  auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) {
     if (!saw_dynamic_dim_value) {
       saw_dynamic_dim_value = true;
       dynamic_dim_value = dim_size;
-      dynamic_dim_expr = expr;
+      dynamic_dim_expr = std::move(expr);
       return;
     }
     if (dynamic_dim_value != dim_size) {
Expand Down Expand Up @@ -545,18 +544,18 @@ absl::Status CompileToLocalExecutable(
std::get<TensorShape>(norm_args[arg_index].shape);
const AttrValue& v = dyn_dim_attr->second;
int64_t idx = v.i();
record_dynamic_dim_value(shp.dim_size(idx), xla::DynExpr::V(1));
record_dynamic_dim_value(shp.dim_size(idx), xla::DExpr::Var(1));
if (!filled_batch && xla_batch_matcher) {
filled_batch =
xla_batch_matcher->get_xla_compile_batch(shp.dim_size(idx));
}

std::vector<xla::DynExpr*> dyn_exprs;
std::vector<xla::DExpr> dyn_exprs;
for (int d : shp.dim_sizes()) {
dyn_exprs.push_back(xla::DynExpr::_(d));
dyn_exprs.push_back(xla::DExpr::Const(d));
}
dyn_exprs[idx] = xla::DynExpr::V(1);
shp.set_expressions(dyn_exprs);
dyn_exprs[idx] = xla::DExpr::Var(1);
shp.set_expressions(std::move(dyn_exprs));
continue;
}
auto it = attr_map.find(kXlaInferredOutputShapesAttrName);
@@ -570,7 +569,7 @@
     for (int idx = 0; idx < exp.size(); ++idx) {
       // Look for dynamic expression. If found then compute padding
       // value and exit loop.
-      auto e = DimExprToDynExpr(ExprFromProto(exp[idx]).get())->s();
+      auto e = DimExprToDExpr(ExprFromProto(exp[idx]).get()).simplify();
       if (e->is_dynamic()) {
         std::optional<int64_t> solved_value =
             e->solve(shp.dim_size(idx));
@@ -595,17 +594,17 @@
         }
       }
 
-      std::vector<xla::DynExpr*> dyn_exprs;
+      std::vector<xla::DExpr> dyn_exprs;
       for (int d : shp.dim_sizes()) {
-        dyn_exprs.push_back(xla::DynExpr::_(d));
+        dyn_exprs.push_back(xla::DExpr::Const(d));
       }
       for (int j = 0; j < exp.size(); ++j) {
-        auto e = DimExprToDynExpr(ExprFromProto(exp[j]).get())->s();
+        auto e = DimExprToDExpr(ExprFromProto(exp[j]).get()).simplify();
         if (e->is_dynamic()) {
           dyn_exprs[j] = e;
         }
       }
-      shp.set_expressions(dyn_exprs);
+      shp.set_expressions(std::move(dyn_exprs));
     }
   }
 }
@@ -691,8 +690,8 @@
         if (e->is_dynamic()) {
           int64_t old = shp.dim_size(j);
           old_vars.push_back({i, j, old});
-          xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch);
-          xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s();
+          xla::DExpr padded_expr = xla::DExpr::Const(filled_batch);
+          xla::DExpr subst_expr = e.substitute(1, padded_expr).simplify();
           int64_t new_dim = subst_expr->get_val();
           if (new_dim >= 0) {
             shp.set_dim(j, new_dim);
@@ -1266,7 +1265,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
     if (!xla_shape.IsArray() || xla_shape.expressions().empty()) continue;
 
     for (int dim = 0; dim < xla_shape.expressions().size(); dim++) {
-      xla::DynExpr* expr = xla_shape.expressions(dim);
+      const auto& expr = xla_shape.expressions(dim);
       if (expr && expr->is_dynamic()) {
         int input_idx = comp_result->input_mapping[i] - num_constant_args;
         if (input_idx < 0 || input_idx >= ctx->num_inputs()) {
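Note on the DynExpr* to DExpr migration in this file: the new spelling (Const/Var factories, infix operators without dereferencing, .simplify() instead of ->s(), std::move into set_expressions) suggests a copyable value type over an immutable shared expression tree. The self-contained toy below illustrates that ownership pattern and the substitute/solve arithmetic used above; every name in it is illustrative, not the real xla::DExpr API.

#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>

// Toy dimension expression: value-semantic handles over an immutable tree.
class Expr {
  enum class Kind { kConst, kVar, kMul };
  struct Node {
    Kind kind;
    int64_t value;  // kConst: the constant; kVar: the variable id
    std::shared_ptr<const Node> lhs, rhs;  // set for kMul only
  };
  std::shared_ptr<const Node> n_;
  explicit Expr(std::shared_ptr<const Node> n) : n_(std::move(n)) {}

 public:
  Expr() = default;  // empty handle, usable as an "unknown" fallback
  static Expr Const(int64_t v) {
    return Expr(std::make_shared<Node>(Node{Kind::kConst, v, nullptr, nullptr}));
  }
  static Expr Var(int64_t id) {
    return Expr(std::make_shared<Node>(Node{Kind::kVar, id, nullptr, nullptr}));
  }
  explicit operator bool() const { return n_ != nullptr; }

  friend Expr operator*(const Expr& a, const Expr& b) {
    return Expr(std::make_shared<Node>(Node{Kind::kMul, 0, a.n_, b.n_}));
  }

  bool is_dynamic() const {
    if (!n_) return false;
    if (n_->kind == Kind::kVar) return true;
    return n_->kind == Kind::kMul && (Expr(n_->lhs).is_dynamic() ||
                                      Expr(n_->rhs).is_dynamic());
  }

  // Returns a new tree with variable `id` replaced; inputs are never mutated,
  // which is what makes freely copied handles safe.
  Expr substitute(int64_t id, const Expr& repl) const {
    if (!n_ || n_->kind == Kind::kConst) return *this;
    if (n_->kind == Kind::kVar) return n_->value == id ? repl : *this;
    return Expr(n_->lhs).substitute(id, repl) *
           Expr(n_->rhs).substitute(id, repl);
  }

  // Constant-folds products of constants.
  Expr simplify() const {
    if (!n_ || n_->kind != Kind::kMul) return *this;
    Expr l = Expr(n_->lhs).simplify();
    Expr r = Expr(n_->rhs).simplify();
    if (l.n_ && r.n_ && l.n_->kind == Kind::kConst &&
        r.n_->kind == Kind::kConst) {
      return Const(l.n_->value * r.n_->value);
    }
    return l * r;
  }

  int64_t get_val() const {
    return (n_ && n_->kind == Kind::kConst) ? n_->value : -1;
  }

  // Recovers the variable from observed == c * v, the shape of the
  // "solve the batch size from a concrete dim" call above.
  std::optional<int64_t> solve(int64_t observed) const {
    Expr s = simplify();
    if (!s.n_) return std::nullopt;
    if (s.n_->kind == Kind::kVar) return observed;
    if (s.n_->kind == Kind::kMul && s.n_->lhs->kind == Kind::kConst &&
        s.n_->rhs->kind == Kind::kVar) {
      int64_t c = s.n_->lhs->value;
      if (c == 0 || observed % c != 0) return std::nullopt;
      return observed / c;
    }
    return std::nullopt;
  }
};

int main() {
  Expr dim = Expr::Const(2) * Expr::Var(1);     // dim = 2 * batch
  std::cout << dim.solve(128).value() << "\n";  // 64: batch recovered from dim
  std::cout << dim.substitute(1, Expr::Const(64)).simplify().get_val()
            << "\n";  // 128: dim materialized for batch 64
}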
18 changes: 9 additions & 9 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -753,34 +753,34 @@ std::unique_ptr<DimExpr> ExprFromProto(const ExpressionProto& proto) {
   }
 }
 
-static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) {
+static xla::DExpr DimExprToDExpr(const DimExpr* e) {
   switch (e->kind()) {
     case DimExpr::Kind::kConstant: {
      auto* ac = static_cast<const Constant*>(e);
-      return xla::DynExpr::_(ac->value());
+      return xla::DExpr::Const(ac->value());
     }
     case DimExpr::Kind::kVariable: {
       auto* av = static_cast<const Variable*>(e);
-      return xla::DynExpr::V(av->id());  // Use 1 all the time for now
+      return xla::DExpr::Var(av->id());  // Use 1 all the time for now
     }
     case DimExpr::Kind::kAdd: {
       auto* ee = static_cast<const ExprAdd*>(e);
-      return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) + DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kSub: {
       auto* ee = static_cast<const ExprSub*>(e);
-      return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) - DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kMul: {
       auto* ee = static_cast<const ExprMul*>(e);
-      return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) * DimExprToDExpr(ee->rhs());
     }
     case DimExpr::Kind::kDiv: {
       auto* ee = static_cast<const ExprDiv*>(e);
-      return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs());
+      return DimExprToDExpr(ee->lhs()) / DimExprToDExpr(ee->rhs());
     }
   }
-  return nullptr;
+  return xla::DExpr();
 }
 
 // Runs Grappler static inference and logs any ExpressionProto found in output
@@ -1879,7 +1879,7 @@ absl::Status MarkForCompilationPassImpl::AssignDimVars(void) {
     }
     for (auto& pDim: (it->second)[output_index]) {
       DimExpr * d= pDim.get();
-      xla::DynExpr * dyn = DimExprToDynExpr(d);
+      xla::DExpr dyn = DimExprToDExpr(d);
       auto new_ids = dyn->get_all_ids();
       for (auto id : new_ids) {
         cluster->add_dim_var(id);
8 changes: 4 additions & 4 deletions tensorflow/compiler/jit/shape_inference.cc
@@ -51,22 +51,22 @@ absl::Status ShapeHandleToTensorShape(
 
   std::vector<int64_t> dims(context->Rank(handle));
   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
-  std::vector<xla::DynExpr*> dyn_exprs;
+  std::vector<xla::DExpr> dyn_exprs;
   if (flags->tf_xla_enable_dynamic_sizes) {
     dyn_exprs.resize(context->Rank(handle));
   }
   for (int32_t i = 0, end = dims.size(); i < end; ++i) {
     dims[i] = context->Value(context->Dim(handle, i));
     if (flags->tf_xla_enable_dynamic_sizes) {
       auto ratio = context->DynamicRatio(context->Dim(handle, i));
-      dyn_exprs[i] = ratio > 0 ? (ratio * *xla::DynExpr::V(1))->s()
-                               : xla::DynExpr::_(dims[i]);
+      dyn_exprs[i] = ratio > 0 ? xla::DExpr::Const(ratio) * xla::DExpr::Var(1)
+                               : xla::DExpr::Const(dims[i]);
     }
   }
   auto status =
       PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
   if (flags->tf_xla_enable_dynamic_sizes) {
-    shape->set_expressions(dyn_exprs);
+    shape->set_expressions(std::move(dyn_exprs));
   }
   return status;
 }
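Note on the hunk above: a DynamicRatio r > 0 means the dim scales linearly with the single batch variable, so the recorded expression is Const(r) * Var(1), while a static dim becomes Const(dim). A toy check of that rule in plain arithmetic, assuming Var(1) stands for the batch:

#include <cassert>
#include <cstdint>

// ratio > 0: dim is ratio * Var(1); otherwise it is the static Const(dim).
int64_t MaterializeDim(int64_t ratio, int64_t static_dim, int64_t batch) {
  return ratio > 0 ? ratio * batch : static_dim;
}

int main() {
  assert(MaterializeDim(3, /*static_dim=*/-1, /*batch=*/64) == 192);
  assert(MaterializeDim(0, /*static_dim=*/10, /*batch=*/64) == 10);
}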
12 changes: 6 additions & 6 deletions tensorflow/compiler/jit/xla_launch_util.cc
@@ -441,13 +441,13 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
       bool has_dynamic = false;
 
       for (int dim = 0; dim < subshape.expressions().size(); ++dim) {
-        auto expr = subshape.expressions(dim);
-        if (expr != nullptr && expr->is_dynamic()) {
+        const auto& expr = subshape.expressions(dim);
+        if (expr && expr->is_dynamic()) {
           has_dynamic = true;
           VLOG(1) << "Current expression is " << expr;
           if (run_options) {
-            xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size());
-            xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s();
+            xla::DExpr batch_size = xla::DExpr::Const(run_options->batch_size());
+            xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify();
             shape.set_dim(dim, subst_expr->get_val());
           } else {
             // TODO: Fallback to BatchSizeResource for now. Remove it later.
@@ -456,9 +456,9 @@
             ScopedStepContainer* step_container = ctx->step_container();
             TF_RETURN_IF_ERROR(step_container->Lookup<BatchSizeResource>(
                 ctx->resource_manager(), BatchSizeResourceName, &bsr));
-            xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize());
+            xla::DExpr batch_size = xla::DExpr::Const(bsr->GetBatchSize());
             // Just substitute Var(1) for now.
-            xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s();
+            xla::DExpr subst_expr = expr.substitute(1, batch_size).simplify();
             shape.set_dim(dim, subst_expr->get_val());
             bsr->Unref();
           }
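Note on the two branches above: the runtime batch comes from run_options when present, and from the step container's BatchSizeResource otherwise; either way the value is substituted for Var(1) and the expression is folded to a concrete dim. A tiny sketch of that control flow, with the substitution specialized to the linear dim = coeff * v1 case for brevity:

#include <cstdint>
#include <iostream>
#include <optional>

// Prefer the batch size carried by run_options; fall back to the resource.
int64_t ResolveBatch(std::optional<int64_t> run_options_batch,
                     int64_t resource_batch) {
  return run_options_batch ? *run_options_batch : resource_batch;
}

// Stands in for expr.substitute(1, Const(batch)).simplify().get_val()
// when the dim expression is coeff * v1.
int64_t MaterializeDim(int64_t coeff, int64_t batch) { return coeff * batch; }

int main() {
  // No run_options: the resource-provided batch (32) is used, so 2*v1 -> 64.
  std::cout << MaterializeDim(2, ResolveBatch(std::nullopt, 32)) << "\n";
}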
10 changes: 5 additions & 5 deletions tensorflow/compiler/tf2xla/kernels/bincount_op.cc
@@ -116,14 +116,14 @@ class DenseBincountOp : public XlaOpKernel {
     auto i_shape =
         xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
     auto i = xla::Iota(ctx->builder(), i_shape, 0);
+    xla::DExpr flattened_expr =
+        (input_shape.expressions(0) * input_shape.expressions(1)).simplify();
     i = xla::Reshape(
         i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1},
-        {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(),
-         xla::DynExpr::one});
+        {flattened_expr, xla::DExpr::Const(1)});
     auto j = xla::Reshape(
         input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1},
-        {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(),
-         xla::DynExpr::one});
+        {flattened_expr, xla::DExpr::Const(1)});
     std::vector<xla::XlaOp> iotas_to_concat;
     iotas_to_concat.push_back(i);
     iotas_to_concat.push_back(j);
@@ -135,7 +135,7 @@
     if (has_weights && !binary_output_) {
       weights = xla::Reshape(
           weights, {input_shape.dimensions(0) * input_shape.dimensions(1)},
-          {(*input_shape.expressions(0) * *input_shape.expressions(1))->s()});
+          {flattened_expr});
       updates = weights;
     }
   } else {
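Note on the bincount change: the three reshapes share one flattened-dim expression, now computed once as flattened_expr instead of rebuilding (expr0 * expr1) at each call site. The underlying shape arithmetic, as a toy:

#include <cstdint>
#include <iostream>
#include <utility>

// [R, C] flattens to [R*C, 1] for the iota/input reshapes ([R*C] for
// weights); the dim expression follows the same product.
std::pair<int64_t, int64_t> FlattenedDims(int64_t rows, int64_t cols) {
  return {rows * cols, 1};
}

int main() {
  auto dims = FlattenedDims(4, 8);
  std::cout << dims.first << "x" << dims.second << "\n";  // 32x1
}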
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -146,7 +146,7 @@ class ConstOp : public XlaOpKernel {
     if (has_dynamic) {
       std::vector<xla::XlaOp> dimension_constants;
       for (int i = 0; i < shape.dims(); ++i) {
-        if (shape.get_expression(i)->is_dynamic()) {
+        if (shape.get_expression(i) && shape.get_expression(i)->is_dynamic()) {
           int32_t dim_val = static_cast<int32_t>(shape.dim_size(i));
           xla::XlaOp scalar_const = xla::ConstantR0<int32_t>(b, dim_val);
           xla::ExpressionProto expr_proto;
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -99,7 +99,7 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
   new_shape.set_dimensions(num_dims - 1, num_groups);
   new_shape.add_dimensions(
       filter_shape.dimensions(num_dims - 1) / num_groups,
-      (*filter_shape.expressions(num_dims - 1) / num_groups)->s());
+      (filter_shape.expressions(num_dims - 1) / num_groups).simplify());
   xla::XlaOp result =
       xla::Reshape(filter, new_shape.dimensions(), new_shape.expressions());
 