47 changes: 21 additions & 26 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -101,6 +101,8 @@ limitations under the License.
namespace tensorflow {

namespace {
+ constexpr char kUserInferredValueContentsAttrName[] =
+     "_user_inferred_value_contents";
using XlaDeviceCompiler =
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
using PjRtDeviceCompiler =
@@ -520,35 +522,35 @@ absl::Status CompileToLocalExecutable(
return;
}

+ auto inferred_shape_it = attr_map.find("user_inferred_shape");
+ auto inferred_contents_it =
+     attr_map.find(kUserInferredValueContentsAttrName);
bool has_dynamic = false;
auto has_dynamic_it = attr_map.find("has_dynamic");
- if (has_dynamic_it == attr_map.end()) {
-   return;
- }
- has_dynamic = has_dynamic_it->second.b();
- if (!has_dynamic) {
-   return;
+ if (has_dynamic_it != attr_map.end()) {
+   has_dynamic = has_dynamic_it->second.b();
}

- auto inferred_shape_it = attr_map.find("user_inferred_shape");
- if (inferred_shape_it == attr_map.end()) {
-   VLOG(1) << "XlaCompileOp saw has_dynamic for const arg "
-           << arg_index << " node=" << node_name
-           << " but no user_inferred_shape attr";
+ if (inferred_contents_it == attr_map.end() &&
+     (!has_dynamic || inferred_shape_it == attr_map.end())) {
return;
}

TensorShapeProto inferred_shape_proto;
- inferred_shape_proto = inferred_shape_it->second.shape();
+ if (inferred_contents_it != attr_map.end()) {
+   if (!inferred_shape_proto.ParseFromString(
+           inferred_contents_it->second.s())) {
+     return;
+   }
+ } else {
+   inferred_shape_proto = inferred_shape_it->second.shape();
+ }

TensorShape inferred_shape(inferred_shape_proto);
- if (!TensorShapeUtils::IsVector(arg.constant_value.shape()) ||
-     arg.constant_value.NumElements() != inferred_shape.dims()) {
-   VLOG(1) << "XlaCompileOp const arg " << arg_index
-           << " node=" << node_name
-           << " has dynamic shape metadata but tensor shape "
-           << arg.constant_value.shape().DebugString()
-           << " does not match inferred rank " << inferred_shape.dims();
+ if (!((TensorShapeUtils::IsVector(arg.constant_value.shape()) &&
+        arg.constant_value.NumElements() == inferred_shape.dims()) ||
+       (TensorShapeUtils::IsScalar(arg.constant_value.shape()) &&
+        inferred_shape.dims() == 1))) {
return;
}

@@ -564,18 +566,11 @@ absl::Status CompileToLocalExecutable(
} else if (arg.constant_value.dtype() == DT_INT64) {
expr.set_constant_value(arg.constant_value.flat<int64_t>()(i));
} else {
VLOG(1) << "XlaCompileOp const arg " << arg_index
<< " node=" << node_name
<< " has unsupported dtype for inferred shape contents: "
<< DataTypeString(arg.constant_value.dtype());
arg.constant_value_expressions.clear();
return;
}
arg.constant_value_expressions.push_back(std::move(expr));
}
VLOG(1) << "XlaCompileOp recovered " << arg.constant_value_expressions.size()
<< " constant_value_expressions for const arg " << arg_index
<< " node=" << node_name << " from user_inferred_shape";
};
auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) {
if (!saw_dynamic_dim_value) {
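For reference, here is a minimal standalone sketch (not part of this PR) of the attribute round trip the hunk above relies on: a producer serializes a TensorShapeProto into the string attribute named by kUserInferredValueContentsAttrName, and XlaCompileOp recovers it with ParseFromString, falling back to the plain user_inferred_shape attr when the contents attr is absent or malformed. The main() harness and the -1 dynamic-entry convention are illustrative assumptions, not code from the diff.

#include <iostream>
#include <string>

#include "tensorflow/core/framework/tensor_shape.pb.h"

int main() {
  // Producer side: encode the inferred contents as a TensorShapeProto and
  // store its wire form in a plain string attribute.
  tensorflow::TensorShapeProto contents;
  contents.add_dim()->set_size(2);   // a statically known entry
  contents.add_dim()->set_size(-1);  // -1 conventionally marks an unknown entry
  std::string serialized;
  contents.SerializeToString(&serialized);

  // Consumer side (mirrors the ParseFromString branch above): a parse failure
  // must not be treated as valid metadata, so the caller bails out instead.
  tensorflow::TensorShapeProto recovered;
  if (!recovered.ParseFromString(serialized)) {
    std::cerr << "malformed contents attr, ignoring\n";
    return 1;
  }
  std::cout << "recovered " << recovered.dim_size() << " entries\n";
  return 0;
}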
58 changes: 32 additions & 26 deletions tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -35,6 +35,9 @@ limitations under the License.
namespace tensorflow {
namespace {

+ constexpr char kUserInferredValueContentsAttrName[] =
+     "_user_inferred_value_contents";

template <typename DstT, typename SrcT>
DstT CastTo(SrcT src) {
return static_cast<DstT>(src);
@@ -200,6 +203,13 @@ int64_t CountDynamicShapeContents(const TensorShapeProto& shape) {
return dynamic_count;
}

+ bool CanAttachContentsFromTensorShapeProto(const TensorShape& tensor_shape,
+                                            const TensorShapeProto& contents) {
+   return (tensor_shape.dims() == 0 && contents.dim_size() == 1) ||
+          (tensor_shape.dims() == 1 &&
+           tensor_shape.dim_size(0) == contents.dim_size());
+ }

class ConstOp : public XlaOpKernel {
public:
explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -219,17 +229,25 @@ class ConstOp : public XlaOpKernel {

bool has_dynamic = false;
TensorShapeProto inferred_shape_proto;
+ TensorShapeProto inferred_value_contents_proto;
+ string inferred_value_contents_serialized;
if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() &&
- has_dynamic) {
-   if (GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape",
-                   &inferred_shape_proto)
-           .ok()) {
-     VLOG(1) << "ConstOp recovered dynamic folded-const metadata with "
-             << "inferred_shape=" << inferred_shape_proto.DebugString()
-             << " dynamic_exprs="
-             << CountDynamicShapeContents(inferred_shape_proto);
+ has_dynamic &&
+ GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape",
+             &inferred_shape_proto)
+     .ok()) {
+ }
+ if (GetNodeAttr(ctx->op_kernel().def(), kUserInferredValueContentsAttrName,
+                 &inferred_value_contents_serialized)
+         .ok()) {
+   if (!inferred_value_contents_proto.ParseFromString(
+           inferred_value_contents_serialized)) {
+     inferred_value_contents_proto.Clear();
}
}
+ const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0;
+ const TensorShapeProto& contents_proto =
+     has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto;

// To avoid blowups for large constants filled with the same value,
// recognize that case and emit a scalar broadcast instead.
@@ -246,14 +264,10 @@ class ConstOp : public XlaOpKernel {
xla::Broadcast(value, shape.dim_sizes(), shape.get_expressions());
XlaExpression output =
XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0));
- if (has_dynamic && shape.dims() == 1 &&
-     shape.dim_size(0) == inferred_shape_proto.dim_size()) {
-   VLOG(1) << "ConstOp attaching shape contents through broadcast fast "
-           << "path with " << shape.dim_size(0)
-           << " entries and dynamic_exprs="
-           << CountDynamicShapeContents(inferred_shape_proto);
+ if ((has_contents_proto || has_dynamic) &&
+     CanAttachContentsFromTensorShapeProto(shape, contents_proto)) {
output.set_contents(
-     BuildShapeContentsFromTensorShapeProto(inferred_shape_proto));
+     BuildShapeContentsFromTensorShapeProto(contents_proto));
}
ctx->SetOutputExpression(0, output);
return;
@@ -264,19 +278,11 @@ class ConstOp : public XlaOpKernel {
OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
errors::InvalidArgument("Cannot parse tensor from proto: ",
proto_.DebugString()));
- if (has_dynamic) {
-   VLOG(1) << "ConstOp tensor path tensor_shape="
-           << tensor.shape().DebugString() << " inferred_rank="
-           << inferred_shape_proto.dim_size();
- }
XlaExpression output = XlaExpression::Constant(tensor);
- if (has_dynamic && tensor.dims() == 1 &&
-     tensor.dim_size(0) == inferred_shape_proto.dim_size()) {
-   VLOG(1) << "ConstOp attaching shape contents to folded const with "
-           << tensor.dim_size(0) << " entries and dynamic_exprs="
-           << CountDynamicShapeContents(inferred_shape_proto);
+ if ((has_contents_proto || has_dynamic) &&
+     CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) {
output.set_contents(
-     BuildShapeContentsFromTensorShapeProto(inferred_shape_proto));
+     BuildShapeContentsFromTensorShapeProto(contents_proto));
}
ctx->SetOutputExpression(0, output);
}
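The guards above, together with the new CanAttachContentsFromTensorShapeProto helper, accept exactly two layouts: a rank-1 constant whose length equals the number of dims recorded in the contents proto, and a scalar constant paired with a single-dim contents proto. Below is a standalone restatement of that predicate with a few illustrative checks; the helper body is copied from the diff, while the test harness and shape values are assumptions for demonstration.

#include <cassert>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

// Same predicate as the one added to const_op.cc above.
static bool CanAttachContentsFromTensorShapeProto(
    const tensorflow::TensorShape& tensor_shape,
    const tensorflow::TensorShapeProto& contents) {
  return (tensor_shape.dims() == 0 && contents.dim_size() == 1) ||
         (tensor_shape.dims() == 1 &&
          tensor_shape.dim_size(0) == contents.dim_size());
}

int main() {
  tensorflow::TensorShapeProto contents;
  contents.add_dim()->set_size(4);
  contents.add_dim()->set_size(-1);  // one static and one dynamic entry

  tensorflow::TensorShape vec2({2});
  tensorflow::TensorShape vec3({3});
  tensorflow::TensorShape scalar;  // default-constructed shape has rank 0

  // A length-2 vector can carry the two recorded entries.
  assert(CanAttachContentsFromTensorShapeProto(vec2, contents));
  // A length-3 vector cannot: entry count and dim count disagree.
  assert(!CanAttachContentsFromTensorShapeProto(vec3, contents));
  // A scalar qualifies only when the contents proto has exactly one dim.
  assert(!CanAttachContentsFromTensorShapeProto(scalar, contents));
  return 0;
}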
4 changes: 4 additions & 0 deletions tensorflow/core/common_runtime/BUILD
@@ -2754,6 +2754,7 @@ tf_cc_test(
":direct_session_internal",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -2769,9 +2770,12 @@ tf_cc_test(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:concat_op",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:gather_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:immutable_constant_op",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:reshape_op",
"//tensorflow/core/kernels:slice_op",
"//tensorflow/core/kernels:topk_op",
"@eigen_archive//:eigen3",
],