diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 2fb97d66b3ac63..e02f930079215b 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -101,6 +101,8 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+constexpr char kUserInferredValueContentsAttrName[] =
+    "_user_inferred_value_contents";
 using XlaDeviceCompiler =
     DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
 using PjRtDeviceCompiler =
     DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>;
@@ -520,35 +522,35 @@ absl::Status CompileToLocalExecutable(
       return;
     }
 
+    auto inferred_shape_it = attr_map.find("user_inferred_shape");
+    auto inferred_contents_it =
+        attr_map.find(kUserInferredValueContentsAttrName);
     bool has_dynamic = false;
     auto has_dynamic_it = attr_map.find("has_dynamic");
-    if (has_dynamic_it == attr_map.end()) {
-      return;
-    }
-    has_dynamic = has_dynamic_it->second.b();
-    if (!has_dynamic) {
-      return;
+    if (has_dynamic_it != attr_map.end()) {
+      has_dynamic = has_dynamic_it->second.b();
     }
-    auto inferred_shape_it = attr_map.find("user_inferred_shape");
-    if (inferred_shape_it == attr_map.end()) {
-      VLOG(1) << "XlaCompileOp saw has_dynamic for const arg "
-              << arg_index << " node=" << node_name
-              << " but no user_inferred_shape attr";
+    if (inferred_contents_it == attr_map.end() &&
+        (!has_dynamic || inferred_shape_it == attr_map.end())) {
       return;
     }
     TensorShapeProto inferred_shape_proto;
-    inferred_shape_proto = inferred_shape_it->second.shape();
+    if (inferred_contents_it != attr_map.end()) {
+      if (!inferred_shape_proto.ParseFromString(
+              inferred_contents_it->second.s())) {
+        return;
+      }
+    } else {
+      inferred_shape_proto = inferred_shape_it->second.shape();
+    }
     TensorShape inferred_shape(inferred_shape_proto);
-    if (!TensorShapeUtils::IsVector(arg.constant_value.shape()) ||
-        arg.constant_value.NumElements() != inferred_shape.dims()) {
-      VLOG(1) << "XlaCompileOp const arg " << arg_index
-              << " node=" << node_name
-              << " has dynamic shape metadata but tensor shape "
-              << arg.constant_value.shape().DebugString()
-              << " does not match inferred rank " << inferred_shape.dims();
+    if (!((TensorShapeUtils::IsVector(arg.constant_value.shape()) &&
+           arg.constant_value.NumElements() == inferred_shape.dims()) ||
+          (TensorShapeUtils::IsScalar(arg.constant_value.shape()) &&
+           inferred_shape.dims() == 1))) {
       return;
     }
@@ -564,18 +566,11 @@ absl::Status CompileToLocalExecutable(
     } else if (arg.constant_value.dtype() == DT_INT64) {
       expr.set_constant_value(arg.constant_value.flat<int64_t>()(i));
     } else {
-      VLOG(1) << "XlaCompileOp const arg " << arg_index
-              << " node=" << node_name
-              << " has unsupported dtype for inferred shape contents: "
-              << DataTypeString(arg.constant_value.dtype());
       arg.constant_value_expressions.clear();
       return;
     }
     arg.constant_value_expressions.push_back(std::move(expr));
   }
-  VLOG(1) << "XlaCompileOp recovered " << arg.constant_value_expressions.size()
-          << " constant_value_expressions for const arg " << arg_index
-          << " node=" << node_name << " from user_inferred_shape";
 };
 auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) {
   if (!saw_dynamic_dim_value) {
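The acceptance gate above allows the recovered constant in exactly two forms: a rank-1 vector with one element per inferred dimension, or a scalar standing in for the single dimension of a rank-1 inferred shape. A minimal standalone sketch of that rule (plain C++ with no TF dependencies; ContentsShapeCompatible is an illustrative name, not part of the patch):

    #include <cassert>
    #include <cstdint>

    // Standalone mock-up of the acceptance gate in XlaCompileOp above: a
    // folded constant may carry shape contents either as a length-R int
    // vector (one entry per inferred dimension) or as a scalar standing for
    // the single dimension of a rank-1 inferred shape.
    bool ContentsShapeCompatible(int tensor_rank, int64_t tensor_num_elements,
                                 int inferred_rank) {
      const bool is_vector = tensor_rank == 1;
      const bool is_scalar = tensor_rank == 0;
      return (is_vector && tensor_num_elements == inferred_rank) ||
             (is_scalar && inferred_rank == 1);
    }

    int main() {
      assert(ContentsShapeCompatible(1, 2, 2));   // [d0, 16] vs a rank-2 shape
      assert(ContentsShapeCompatible(0, 1, 1));   // scalarized single dimension
      assert(!ContentsShapeCompatible(1, 2, 3));  // element count != rank: reject
      assert(!ContentsShapeCompatible(2, 4, 2));  // matrix-valued const: reject
    }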
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index a911f28246da77..e3650249405aa4 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -35,6 +35,9 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+constexpr char kUserInferredValueContentsAttrName[] =
+    "_user_inferred_value_contents";
+
 template <typename DstT, typename SrcT>
 DstT CastTo(SrcT src) {
   return static_cast<DstT>(src);
 }
@@ -200,6 +203,13 @@ int64_t CountDynamicShapeContents(const TensorShapeProto& shape) {
   return dynamic_count;
 }
 
+bool CanAttachContentsFromTensorShapeProto(const TensorShape& tensor_shape,
+                                           const TensorShapeProto& contents) {
+  return (tensor_shape.dims() == 0 && contents.dim_size() == 1) ||
+         (tensor_shape.dims() == 1 &&
+          tensor_shape.dim_size(0) == contents.dim_size());
+}
+
 class ConstOp : public XlaOpKernel {
  public:
   explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -219,17 +229,25 @@ class ConstOp : public XlaOpKernel {
     bool has_dynamic = false;
     TensorShapeProto inferred_shape_proto;
+    TensorShapeProto inferred_value_contents_proto;
+    string inferred_value_contents_serialized;
     if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() &&
-        has_dynamic) {
-      if (GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape",
-                      &inferred_shape_proto)
-              .ok()) {
-        VLOG(1) << "ConstOp recovered dynamic folded-const metadata with "
-                << "inferred_shape=" << inferred_shape_proto.DebugString()
-                << " dynamic_exprs="
-                << CountDynamicShapeContents(inferred_shape_proto);
+        has_dynamic &&
+        GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape",
+                    &inferred_shape_proto)
+            .ok()) {
+    }
+    if (GetNodeAttr(ctx->op_kernel().def(), kUserInferredValueContentsAttrName,
+                    &inferred_value_contents_serialized)
+            .ok()) {
+      if (!inferred_value_contents_proto.ParseFromString(
+              inferred_value_contents_serialized)) {
+        inferred_value_contents_proto.Clear();
       }
     }
+    const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0;
+    const TensorShapeProto& contents_proto =
+        has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto;
 
     // To avoid blowups for large constants filled with the same value,
     // recognize that case and emit a scalar broadcast instead.
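CanAttachContentsFromTensorShapeProto above is the entire attachment contract: contents ride on a rank-1 const whose length matches the number of content entries, or on a scalar const described by exactly one entry. A standalone restatement with bare ints (illustrative only, not TF code):

    #include <cassert>
    #include <cstdint>

    // Mock of CanAttachContentsFromTensorShapeProto above, with ints in place
    // of TensorShape/TensorShapeProto.
    bool CanAttach(int tensor_rank, int64_t dim0, int contents_entries) {
      return (tensor_rank == 0 && contents_entries == 1) ||
             (tensor_rank == 1 && dim0 == contents_entries);
    }

    int main() {
      assert(CanAttach(/*tensor_rank=*/1, /*dim0=*/2, /*contents_entries=*/2));
      assert(CanAttach(0, /*dim0 unused*/ 0, 1));  // scalarized shape element
      assert(!CanAttach(1, 3, 2));                 // length mismatch
      assert(!CanAttach(2, 2, 2));                 // rank > 1 never carries contents
    }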
@@ -246,14 +264,10 @@ class ConstOp : public XlaOpKernel {
         xla::Broadcast(value, shape.dim_sizes(), shape.get_expressions());
     XlaExpression output =
         XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0));
-    if (has_dynamic && shape.dims() == 1 &&
-        shape.dim_size(0) == inferred_shape_proto.dim_size()) {
-      VLOG(1) << "ConstOp attaching shape contents through broadcast fast "
-              << "path with " << shape.dim_size(0)
-              << " entries and dynamic_exprs="
-              << CountDynamicShapeContents(inferred_shape_proto);
+    if ((has_contents_proto || has_dynamic) &&
+        CanAttachContentsFromTensorShapeProto(shape, contents_proto)) {
       output.set_contents(
-          BuildShapeContentsFromTensorShapeProto(inferred_shape_proto));
+          BuildShapeContentsFromTensorShapeProto(contents_proto));
     }
     ctx->SetOutputExpression(0, output);
     return;
@@ -264,19 +278,11 @@ class ConstOp : public XlaOpKernel {
     OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
                 errors::InvalidArgument("Cannot parse tensor from proto: ",
                                         proto_.DebugString()));
-    if (has_dynamic) {
-      VLOG(1) << "ConstOp tensor path tensor_shape="
-              << tensor.shape().DebugString() << " inferred_rank="
-              << inferred_shape_proto.dim_size();
-    }
     XlaExpression output = XlaExpression::Constant(tensor);
-    if (has_dynamic && tensor.dims() == 1 &&
-        tensor.dim_size(0) == inferred_shape_proto.dim_size()) {
-      VLOG(1) << "ConstOp attaching shape contents to folded const with "
-              << tensor.dim_size(0) << " entries and dynamic_exprs="
-              << CountDynamicShapeContents(inferred_shape_proto);
+    if ((has_contents_proto || has_dynamic) &&
+        CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) {
       output.set_contents(
-          BuildShapeContentsFromTensorShapeProto(inferred_shape_proto));
+          BuildShapeContentsFromTensorShapeProto(contents_proto));
     }
     ctx->SetOutputExpression(0, output);
   }
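Taken together, the attach decision in ConstOp::Compile reduces to a small precedence rule: a parseable, non-empty _user_inferred_value_contents attr wins; the legacy user_inferred_shape proto applies only when has_dynamic is set; and either way the shape gate above must pass. A hedged standalone mock of that rule (enum and function names are illustrative, not from the patch):

    #include <cassert>

    // Mock of the contents-source selection in ConstOp::Compile above.
    enum class Attach { kNone, kFromContentsAttr, kFromLegacyShape };

    Attach Decide(bool has_contents_proto, bool has_dynamic,
                  bool shape_compatible) {
      if (!shape_compatible || !(has_contents_proto || has_dynamic))
        return Attach::kNone;
      return has_contents_proto ? Attach::kFromContentsAttr
                                : Attach::kFromLegacyShape;
    }

    int main() {
      assert(Decide(true, false, true) == Attach::kFromContentsAttr);
      assert(Decide(false, true, true) == Attach::kFromLegacyShape);
      assert(Decide(false, false, true) == Attach::kNone);  // no metadata at all
      assert(Decide(true, true, false) == Attach::kNone);   // shape gate rejects
    }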
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 301015eba61fe8..24089b11472b6b 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2754,6 +2754,7 @@ tf_cc_test(
         ":direct_session_internal",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/cc:function_ops",
         "//tensorflow/cc:sendrecv_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -2769,9 +2770,12 @@ tf_cc_test(
         "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/kernels:concat_op",
         "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:gather_op",
         "//tensorflow/core/kernels:identity_op",
         "//tensorflow/core/kernels:immutable_constant_op",
        "//tensorflow/core/kernels:matmul_op",
+        "//tensorflow/core/kernels:reshape_op",
+        "//tensorflow/core/kernels:slice_op",
         "//tensorflow/core/kernels:topk_op",
         "@eigen_archive//:eigen3",
     ],
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index e4427877c33cce..ebb050cec35bb3 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -53,6 +53,8 @@ namespace {
 const char kScopedAllocatorAttrName[] = "_scoped_allocator";
 const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived";
+const char kUserInferredValueContentsAttrName[] =
+    "_user_inferred_value_contents";
 
 bool IsShapeOp(const Node* n);
@@ -80,6 +82,420 @@ bool GetShapeFromDirectDynamicSource(const Node* node,
          GetShapeFromArgNode(node, out_shape);
 }
 
+bool TryParseSerializedContentsAttr(const AttrSlice& attrs,
+                                    TensorShapeProto* out_contents) {
+  string serialized_contents;
+  if (!GetNodeAttr(attrs, kUserInferredValueContentsAttrName,
+                   &serialized_contents)
+           .ok()) {
+    return false;
+  }
+  out_contents->Clear();
+  return out_contents->ParseFromString(serialized_contents);
+}
+
+bool TryGetContentsProtoAttr(const AttrSlice& attrs,
+                             TensorShapeProto* out_contents) {
+  return TryParseSerializedContentsAttr(attrs, out_contents);
+}
+
+bool HasTransitiveDynamicShapeContents(
+    const Node* node, std::unordered_map<const Node*, bool>* memo) {
+  auto it = memo->find(node);
+  if (it != memo->end()) {
+    return it->second;
+  }
+
+  TensorShapeProto contents_proto;
+  if (TryGetContentsProtoAttr(node->attrs(), &contents_proto) &&
+      HasDynamicDimExprs(contents_proto)) {
+    return (*memo)[node] = true;
+  }
+
+  bool has_dynamic = false;
+  TensorShapeProto inferred_shape_proto;
+  if (GetNodeAttr(node->attrs(), "has_dynamic", &has_dynamic).ok() &&
+      has_dynamic &&
+      GetNodeAttr(node->attrs(), "user_inferred_shape", &inferred_shape_proto)
+          .ok() &&
+      HasDynamicDimExprs(inferred_shape_proto)) {
+    return (*memo)[node] = true;
+  }
+
+  if (GetShapeFromDirectDynamicSource(node, &inferred_shape_proto) ||
+      node->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) {
+    return (*memo)[node] = true;
+  }
+
+  for (const Edge* edge : node->in_edges()) {
+    if (edge->IsControlEdge()) continue;
+    if (HasTransitiveDynamicShapeContents(edge->src(), memo)) {
+      return (*memo)[node] = true;
+    }
+  }
+
+  return (*memo)[node] = false;
+}
+
+bool GetConstTensor(const Node* node, Tensor* tensor) {
+  if (node == nullptr || !node->IsConstant()) {
+    return false;
+  }
+  const TensorProto* tensor_proto;
+  if (!GetNodeAttr(node->attrs(), "value", &tensor_proto).ok()) {
+    return false;
+  }
+  DataType dtype;
+  if (!GetNodeAttr(node->attrs(), "dtype", &dtype).ok()) {
+    return false;
+  }
+  *tensor = Tensor(dtype);
+  return tensor->FromProto(cpu_allocator(), *tensor_proto);
+}
+
+bool GetInputConstTensor(const Node* node, int input_index, Tensor* tensor) {
+  const Edge* edge;
+  if (!node->input_edge(input_index, &edge).ok()) {
+    return false;
+  }
+  return GetConstTensor(edge->src(), tensor);
+}
+
+bool GetTensorIntValues(const Tensor& tensor, std::vector<int64_t>* values) {
+  values->clear();
+  if (tensor.dims() == 0) {
+    values->reserve(1);
+    switch (tensor.dtype()) {
+      case DT_INT32:
+        values->push_back(tensor.scalar<int32>()());
+        return true;
+      case DT_INT64:
+        values->push_back(tensor.scalar<int64_t>()());
+        return true;
+      default:
+        return false;
+    }
+  }
+  if (tensor.dims() != 1) {
+    return false;
+  }
+  values->reserve(tensor.NumElements());
+  switch (tensor.dtype()) {
+    case DT_INT32: {
+      auto flat = tensor.flat<int32>();
+      for (int i = 0; i < flat.size(); ++i) values->push_back(flat(i));
+      return true;
+    }
+    case DT_INT64: {
+      auto flat = tensor.flat<int64_t>();
+      for (int i = 0; i < flat.size(); ++i) values->push_back(flat(i));
+      return true;
+    }
+    default:
+      return false;
+  }
+}
+
+void CopyContentAt(const TensorShapeProto& input_contents, int64_t index,
+                   TensorShapeProto* output_contents) {
+  output_contents->add_dim()->CopyFrom(input_contents.dim(index));
+  if (index < input_contents.expressions_size()) {
+    output_contents->add_expressions()->CopyFrom(
+        input_contents.expressions(index));
+  }
+}
+
+void AppendScalarConstantContent(int64_t value,
+                                 TensorShapeProto* output_contents) {
+  output_contents->add_dim()->set_size(value);
+}
+
+void AppendScalarContentFromTensor(const Tensor& tensor,
+                                   TensorShapeProto* output_contents) {
+  if (tensor.dtype() == DT_INT32) {
+    AppendScalarConstantContent(tensor.scalar<int32>()(), output_contents);
+  } else {
+    AppendScalarConstantContent(tensor.scalar<int64_t>()(), output_contents);
+  }
+}
+
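GetTensorIntValues is the workhorse that turns every const operand (indices, axes, begin/size vectors) into plain integers: only int32/int64 scalars and rank-1 tensors are accepted, and anything else makes the caller give up on symbolic folding. A standalone mirror of that shape/rank gating (plain C++; the dtype-enum dispatch is elided here):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mock of GetTensorIntValues above: scalars and rank-1 int tensors pass,
    // everything else rejects.
    template <typename T>
    bool ExtractIntValues(const std::vector<T>& data, int rank,
                          std::vector<int64_t>* out) {
      out->clear();
      if (rank > 1) return false;
      if (rank == 0 && data.size() != 1) return false;
      out->assign(data.begin(), data.end());
      return true;
    }

    int main() {
      std::vector<int64_t> out;
      assert(ExtractIntValues<int32_t>({7}, /*rank=*/0, &out) && out[0] == 7);
      assert(ExtractIntValues<int64_t>({1, 2, 3}, /*rank=*/1, &out) &&
             out.size() == 3);
      assert(!ExtractIntValues<int32_t>({1, 2, 3, 4}, /*rank=*/2, &out));
    }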
+ExpressionProto MakeConstantExpressionProto(int64_t value) {
+  ExpressionProto expr;
+  expr.set_constant_value(value);
+  return expr;
+}
+
+ExpressionProto GetContentExpressionProto(const TensorShapeProto& contents,
+                                          int64_t index) {
+  if (index < contents.expressions_size()) {
+    return contents.expressions(index);
+  }
+  return MakeConstantExpressionProto(contents.dim(index).size());
+}
+
+ExpressionProto MakeMulExpressionProto(ExpressionProto lhs,
+                                       ExpressionProto rhs) {
+  ExpressionProto expr;
+  auto* mul = expr.mutable_mul_node();
+  *mul->mutable_lhs() = std::move(lhs);
+  *mul->mutable_rhs() = std::move(rhs);
+  return expr;
+}
+
+bool TryGetFoldedValueContents(const Node* node, int output_index,
+                               TensorShapeProto* out_contents) {
+  out_contents->Clear();
+  if (output_index != 0) {
+    return false;
+  }
+
+  if (TryParseSerializedContentsAttr(node->attrs(), out_contents)) {
+    return true;
+  }
+
+  bool has_dynamic = false;
+  TensorShapeProto user_inferred_shape;
+  if (GetNodeAttr(node->attrs(), "has_dynamic", &has_dynamic).ok() &&
+      has_dynamic &&
+      GetNodeAttr(node->attrs(), "user_inferred_shape", &user_inferred_shape)
+          .ok()) {
+    *out_contents = user_inferred_shape;
+    return true;
+  }
+
+  if (GetShapeFromDirectDynamicSource(node, out_contents)) {
+    return true;
+  }
+
+  auto recurse_input = [&](int input_index,
+                           TensorShapeProto* input_contents) -> bool {
+    const Edge* input_edge;
+    if (!node->input_edge(input_index, &input_edge).ok()) {
+      return false;
+    }
+    return TryGetFoldedValueContents(input_edge->src(),
+                                     input_edge->src_output(), input_contents);
+  };
+
+  if (node->IsIdentity() || node->type_string() == "Cast") {
+    return recurse_input(0, out_contents);
+  }
+
+  TensorShapeProto input_contents;
+  if (node->type_string() == "Reshape" && recurse_input(0, &input_contents)) {
+    Tensor shape_tensor;
+    std::vector<int64_t> shape_dims;
+    if (!GetInputConstTensor(node, 1, &shape_tensor) ||
+        !GetTensorIntValues(shape_tensor, &shape_dims)) {
+      return false;
+    }
+    if (input_contents.dim_size() == 1 && shape_dims.empty()) {
+      CopyContentAt(input_contents, 0, out_contents);
+      return true;
+    }
+    if (shape_dims.size() == 1 && input_contents.dim_size() == shape_dims[0]) {
+      out_contents->CopyFrom(input_contents);
+      return true;
+    }
+    return false;
+  }
+
+  if (node->type_string() == "Pack") {
+    for (int i = 0; i < node->num_inputs(); ++i) {
+      TensorShapeProto scalar_contents;
+      Tensor scalar_tensor;
+      if (recurse_input(i, &scalar_contents)) {
+        if (scalar_contents.dim_size() != 1) {
+          return false;
+        }
+        CopyContentAt(scalar_contents, 0, out_contents);
+      } else if (GetInputConstTensor(node, i, &scalar_tensor) &&
+                 scalar_tensor.dims() == 0 &&
+                 (scalar_tensor.dtype() == DT_INT32 ||
+                  scalar_tensor.dtype() == DT_INT64)) {
+        AppendScalarContentFromTensor(scalar_tensor, out_contents);
+      } else {
+        return false;
+      }
+    }
+    return true;
+  }
+
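GetContentExpressionProto encodes the convention the rest of the traversal relies on: the expressions list runs parallel to dim, may be shorter, and a missing entry degrades to a constant expression equal to the dim size. A standalone mock of that lookup (Expr is a stand-in for the patch's ExpressionProto, which is not an upstream TF type):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mock of the dims/expressions convention behind GetContentExpressionProto.
    struct Expr { bool symbolic; int64_t constant; };

    Expr ContentExpressionAt(const std::vector<int64_t>& dims,
                             const std::vector<Expr>& exprs, size_t index) {
      if (index < exprs.size()) return exprs[index];
      return Expr{/*symbolic=*/false, /*constant=*/dims[index]};
    }

    int main() {
      std::vector<int64_t> dims = {977, 16};
      std::vector<Expr> exprs = {{/*symbolic=*/true, -1}};  // only dim 0 dynamic
      assert(ContentExpressionAt(dims, exprs, 0).symbolic);
      assert(!ContentExpressionAt(dims, exprs, 1).symbolic &&
             ContentExpressionAt(dims, exprs, 1).constant == 16);
    }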
+  if (node->type_string() == "ConcatV2") {
+    Tensor axis_tensor;
+    std::vector<int64_t> axis_values;
+    if (!GetInputConstTensor(node, node->num_inputs() - 1, &axis_tensor) ||
+        !GetTensorIntValues(axis_tensor, &axis_values) ||
+        axis_values.size() != 1) {
+      return false;
+    }
+    int64_t axis = axis_values[0];
+    if (axis != 0 && axis != -1) {
+      return false;
+    }
+    for (int i = 0; i < node->num_inputs() - 1; ++i) {
+      TensorShapeProto part_contents;
+      if (!recurse_input(i, &part_contents)) {
+        return false;
+      }
+      for (int64_t j = 0; j < part_contents.dim_size(); ++j) {
+        CopyContentAt(part_contents, j, out_contents);
+      }
+    }
+    return true;
+  }
+
+  if ((node->type_string() == "Gather" || node->type_string() == "GatherV2") &&
+      recurse_input(0, &input_contents)) {
+    Tensor indices_tensor;
+    std::vector<int64_t> indices;
+    if (!GetInputConstTensor(node, 1, &indices_tensor) ||
+        !GetTensorIntValues(indices_tensor, &indices)) {
+      return false;
+    }
+
+    int64_t axis = 0;
+    if (node->type_string() == "GatherV2") {
+      Tensor axis_tensor;
+      std::vector<int64_t> axis_values;
+      if (!GetInputConstTensor(node, 2, &axis_tensor) ||
+          !GetTensorIntValues(axis_tensor, &axis_values) ||
+          axis_values.size() != 1) {
+        return false;
+      }
+      axis = axis_values[0];
+    }
+
+    const int64_t params_rank = 1;
+    if (axis < 0) axis += params_rank;
+    if (axis != 0) {
+      return false;
+    }
+
+    const int64_t rank = input_contents.dim_size();
+    for (int64_t index : indices) {
+      if (index < 0) index += rank;
+      if (index < 0 || index >= rank) {
+        return false;
+      }
+      CopyContentAt(input_contents, index, out_contents);
+    }
+    return true;
+  }
+
+  if (node->type_string() == "Prod" && recurse_input(0, &input_contents)) {
+    Tensor reduction_indices_tensor;
+    std::vector<int64_t> axes;
+    bool keep_dims = false;
+    if (!GetInputConstTensor(node, 1, &reduction_indices_tensor) ||
+        !GetTensorIntValues(reduction_indices_tensor, &axes) ||
+        !GetNodeAttr(node->attrs(), "keep_dims", &keep_dims).ok() ||
+        keep_dims || axes.size() != 1 ||
+        (axes[0] != 0 && axes[0] != -1) || input_contents.dim_size() == 0) {
+      return false;
+    }
+    int64_t value = 1;
+    ExpressionProto expr = GetContentExpressionProto(input_contents, 0);
+    for (int64_t i = 0; i < input_contents.dim_size(); ++i) {
+      value *= input_contents.dim(i).size();
+      if (i > 0) {
+        expr = MakeMulExpressionProto(
+            std::move(expr), GetContentExpressionProto(input_contents, i));
+      }
+    }
+    out_contents->add_dim()->set_size(value);
+    out_contents->add_expressions()->Swap(&expr);
+    return true;
+  }
+
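The Prod case is the one place the traversal synthesizes a new expression rather than copying one: the concrete value is the plain product of the dim sizes, while the symbolic side folds the per-dim expressions left-to-right into a Mul tree, e.g. mul(mul(d0, d1), d2). A standalone sketch of that fold (a shared_ptr tree stands in for the patch's ExpressionProto):

    #include <cassert>
    #include <cstdint>
    #include <memory>
    #include <vector>

    // Mock expression node: concrete value plus optional operands.
    struct Node {
      int64_t value;
      std::shared_ptr<Node> lhs, rhs;
    };

    std::shared_ptr<Node> Leaf(int64_t v) {
      return std::make_shared<Node>(Node{v});
    }
    std::shared_ptr<Node> Mul(std::shared_ptr<Node> l, std::shared_ptr<Node> r) {
      return std::make_shared<Node>(Node{l->value * r->value, l, r});
    }

    int main() {
      std::vector<int64_t> dims = {977, 16, 4};
      auto expr = Leaf(dims[0]);
      int64_t value = dims[0];
      for (size_t i = 1; i < dims.size(); ++i) {
        value *= dims[i];
        expr = Mul(expr, Leaf(dims[i]));
      }
      assert(value == 977 * 16 * 4);
      assert(expr->lhs->lhs->value == 977);  // left-leaning: mul(mul(d0, d1), d2)
    }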
+  if (node->type_string() == "Slice" && recurse_input(0, &input_contents)) {
+    Tensor begin_tensor;
+    Tensor size_tensor;
+    std::vector<int64_t> begin;
+    std::vector<int64_t> size;
+    if (!GetInputConstTensor(node, 1, &begin_tensor) ||
+        !GetInputConstTensor(node, 2, &size_tensor) ||
+        !GetTensorIntValues(begin_tensor, &begin) ||
+        !GetTensorIntValues(size_tensor, &size) || begin.size() != 1 ||
+        size.size() != 1) {
+      return false;
+    }
+    int64_t start = begin[0];
+    if (start < 0 || start > input_contents.dim_size()) {
+      return false;
+    }
+    int64_t length = size[0] < 0 ? input_contents.dim_size() - start : size[0];
+    if (length < 0 || start + length > input_contents.dim_size()) {
+      return false;
+    }
+    for (int64_t i = 0; i < length; ++i) {
+      CopyContentAt(input_contents, start + i, out_contents);
+    }
+    return true;
+  }
+
+  if (node->type_string() == "StridedSlice" &&
+      recurse_input(0, &input_contents)) {
+    Tensor begin_tensor;
+    Tensor end_tensor;
+    Tensor strides_tensor;
+    std::vector<int64_t> begin;
+    std::vector<int64_t> end;
+    std::vector<int64_t> strides;
+    int64_t begin_mask = 0;
+    int64_t end_mask = 0;
+    int64_t ellipsis_mask = 0;
+    int64_t new_axis_mask = 0;
+    int64_t shrink_axis_mask = 0;
+    if (!GetInputConstTensor(node, 1, &begin_tensor) ||
+        !GetInputConstTensor(node, 2, &end_tensor) ||
+        !GetInputConstTensor(node, 3, &strides_tensor) ||
+        !GetTensorIntValues(begin_tensor, &begin) ||
+        !GetTensorIntValues(end_tensor, &end) ||
+        !GetTensorIntValues(strides_tensor, &strides) || begin.size() != 1 ||
+        end.size() != 1 || strides.size() != 1 ||
+        !GetNodeAttr(node->attrs(), "begin_mask", &begin_mask).ok() ||
+        !GetNodeAttr(node->attrs(), "end_mask", &end_mask).ok() ||
+        !GetNodeAttr(node->attrs(), "ellipsis_mask", &ellipsis_mask).ok() ||
+        !GetNodeAttr(node->attrs(), "new_axis_mask", &new_axis_mask).ok() ||
+        !GetNodeAttr(node->attrs(), "shrink_axis_mask", &shrink_axis_mask).ok()) {
+      return false;
+    }
+    if (ellipsis_mask != 0 || new_axis_mask != 0) {
+      return false;
+    }
+    const int64_t rank = input_contents.dim_size();
+    int64_t stride = strides[0];
+    if (stride == 0) {
+      return false;
+    }
+    int64_t start = (begin_mask & 1) ? (stride > 0 ? 0 : rank - 1) : begin[0];
+    int64_t stop = (end_mask & 1) ? (stride > 0 ? rank : -1) : end[0];
+    if (start < 0) start += rank;
+    if (stop < 0 && !(end_mask & 1 && stride < 0)) stop += rank;
+    if (shrink_axis_mask & 1) {
+      if (start < 0 || start >= rank) {
+        return false;
+      }
+      CopyContentAt(input_contents, start, out_contents);
+      return true;
+    }
+    if (stride < 0) {
+      return false;
+    }
+    start = std::max<int64_t>(0, start);
+    stop = std::min(rank, stop);
+    for (int64_t i = start; i < stop; i += stride) {
+      CopyContentAt(input_contents, i, out_contents);
+    }
+    return true;
+  }
+
+  return false;
+}
+
 // For stateless RNGs ops, they are pure but device-dependent. Those ops are not
 // constant-foldable.
 static absl::flat_hash_set<std::string>* kBlockList =
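The StridedSlice case only honors bit 0 of each mask, which is all a rank-1 shape vector needs. A simplified standalone mock of the begin/end normalization for positive strides (negative strides, shrink_axis, and the remaining mask bits are deliberately omitted here):

    #include <cassert>
    #include <cstdint>

    // Mock of the rank-1 StridedSlice normalization above, positive stride
    // only: masked begin/end snap to the full range, negative indices wrap
    // once, then the range is clamped to [0, rank].
    struct Range { int64_t start, stop; };

    Range Normalize(int64_t rank, int64_t begin, int64_t end, bool begin_mask,
                    bool end_mask) {
      int64_t start = begin_mask ? 0 : begin;
      int64_t stop = end_mask ? rank : end;
      if (start < 0) start += rank;
      if (stop < 0) stop += rank;
      if (start < 0) start = 0;
      if (stop > rank) stop = rank;
      return {start, stop};
    }

    int main() {
      // shape[1:] over a rank-3 shape vector -> entries 1 and 2.
      Range r = Normalize(/*rank=*/3, /*begin=*/1, /*end=*/0, false, true);
      assert(r.start == 1 && r.stop == 3);
      // shape[:-1] -> entries 0 and 1.
      r = Normalize(3, 0, -1, true, false);
      assert(r.start == 0 && r.stop == 2);
    }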
@@ -273,16 +689,17 @@ bool IsConstantFoldable(
     const std::function<bool(const Node*)>& consider,
     int64_t max_constant_size_in_bytes,
     std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map) {
+  TensorShapeProto exact_contents;
+  const bool has_exact_contents = TryGetFoldedValueContents(n, 0, &exact_contents);
   TensorShapeProto dynamic_shape;
-  if (GetShapeFromDirectDynamicSource(n, &dynamic_shape)) {
-    VLOG(1) << "Skipping constant folding for dynamic shape-derived node "
-            << n->name() << " op=" << n->type_string()
-            << " inferred_shape=" << dynamic_shape.DebugString();
-    return false;
-  }
-  if (n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) {
-    VLOG(1) << "Skipping constant folding for shape-derived node "
-            << n->name() << " op=" << n->type_string();
+  const bool has_dynamic = GetShapeFromDirectDynamicSource(n, &dynamic_shape);
+  const bool is_shape_derived =
+      n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr;
+  std::unordered_map<const Node*, bool> dynamic_contents_memo;
+  const bool has_transitive_dynamic_contents =
+      HasTransitiveDynamicShapeContents(n, &dynamic_contents_memo);
+  if ((has_dynamic || is_shape_derived || has_transitive_dynamic_contents) &&
+      (!has_exact_contents || n->num_outputs() > 1)) {
     return false;
   }
   if (n->IsConstant()) {
@@ -492,6 +909,8 @@ void AddShapeNodeToConstantGraph(
   TensorShapeProto user_inferred_shape;
   const bool has_dynamic =
       GetShapeFromDirectDynamicSource(n, &user_inferred_shape);
+  TensorShapeProto exact_contents;
+  const bool has_exact_contents = TryGetFoldedValueContents(n, 0, &exact_contents);
   std::vector<Node*>& added = (*node_map)[n];
   const string& node_name = n->name();
   for (const Tensor& t : shape_replacement_map.at(n)) {
@@ -505,6 +924,10 @@ void AddShapeNodeToConstantGraph(
       builder.Attr("has_dynamic", has_dynamic)
           .Attr("user_inferred_shape", user_inferred_shape);
     }
+    if (has_exact_contents && HasDynamicDimExprs(exact_contents)) {
+      builder.Attr(kUserInferredValueContentsAttrName,
+                   exact_contents.SerializeAsString());
+    }
     NodeDef def;
     CHECK(builder.Finalize(&def).ok());
     Node* constant_node;
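The reworked IsConstantFoldable gate, reduced to a truth table: dynamic taint of any kind (a direct dynamic source, _xla_shape_derived, or a transitively dynamic input) is tolerable only when the node's exact symbolic contents were recovered and it has a single output. A standalone mock (the function name is illustrative; the real code additionally applies the size, blocklist, and op-specific checks that follow this gate):

    #include <cassert>

    // "May pass the dynamic-taint gate" -- not "will definitely fold".
    bool MayFold(bool dynamic_tainted, bool has_exact_contents,
                 int num_outputs) {
      if (!dynamic_tainted) return true;
      return has_exact_contents && num_outputs == 1;
    }

    int main() {
      assert(MayFold(false, false, 2));   // untainted nodes fold as before
      assert(MayFold(true, true, 1));     // tainted but fully recovered: fold
      assert(!MayFold(true, false, 1));   // contents lost: keep the node
      assert(!MayFold(true, true, 2));    // multi-output: too risky, keep
    }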
@@ -595,6 +1018,12 @@ bool ReplaceTensorWithConstant(
           ? DeviceType{partition_device->device_type()}
           : DEVICE_CPU;
   if (partition_device && device_type != DEVICE_CPU) {
+    // Constant folding replaces one output edge-set at a time. Be
+    // conservative for non-CPU multi-output ops, since partially replacing a
+    // node can violate per-output placement or memory-type assumptions.
+    if (tensor.first->num_outputs() > 1) {
+      return false;
+    }
     MemoryTypeVector input_mvec;
     MemoryTypeVector output_mvec;
     if (!MemoryTypesForNode(graph->op_registry(), device_type,
@@ -626,6 +1055,24 @@ bool ReplaceTensorWithConstant(
   TensorShapeProto user_inferred_shape;
   const bool has_dynamic =
       GetShapeFromDirectDynamicSource(tensor.first, &user_inferred_shape);
+  const bool is_shape_derived =
+      tensor.first->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr;
+  if (tensor.second != 0 && (has_dynamic || is_shape_derived)) {
+    VLOG(1) << "Skipping replacement of " << tensor.first->name() << ":"
+            << tensor.second
+            << " because symbolic content preservation only supports "
+            << "replacing output 0";
+    return false;
+  }
+  TensorShapeProto exact_contents;
+  const bool has_exact_contents =
+      TryGetFoldedValueContents(tensor.first, tensor.second, &exact_contents);
+  if ((has_dynamic || is_shape_derived) && !has_exact_contents) {
+    VLOG(1) << "Skipping replacement of " << tensor.first->name() << ":"
+            << tensor.second
+            << " because constant folding could not preserve symbolic contents";
+    return false;
+  }
   Node* constant_node;
   auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
                      .Attr("dtype", constant.dtype())
                      .Attr("value", constant);
   if (has_dynamic) {
     builder.Attr("has_dynamic", has_dynamic)
         .Attr("user_inferred_shape", user_inferred_shape);
   }
+  if (has_exact_contents && HasDynamicDimExprs(exact_contents)) {
+    builder.Attr("has_dynamic", true)
+        .Attr(kUserInferredValueContentsAttrName,
+              exact_contents.SerializeAsString());
+  }
   if (partition_device) {
     builder.Device(partition_device->name());
   }
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index 481a85add4893c..18ae0e74a007b6 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/cc/ops/array_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/nn_ops.h"
 #include "tensorflow/cc/ops/sendrecv_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -32,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape_expr.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -45,6 +47,15 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+TensorShapeProto MakeDynamicShapeProto977x16() {
+  TensorShapeProto proto;
+  proto.add_dim()->set_size(977);
+  proto.add_dim()->set_size(16);
+  proto.add_expressions()->set_variable_id(0);
+  proto.add_expressions()->set_constant_value(16);
+  return proto;
+}
+
 class ConstantFoldingTest : public ::testing::Test {
  protected:
   template <typename T>
@@ -634,6 +645,196 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) {
   }
 }
 
+TEST_F(ConstantFoldingTest, FoldShapeFromDynamicArgPreservesContents) {
+  Graph g(OpRegistry::Global());
+  Scope s = Scope::NewRootScope();
+  auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0);
+  auto shape = ops::Shape(s.WithOpName("shape"), arg);
+  auto send = ops::_Send(s.WithOpName("send"), shape, "send", "sender", 0,
+                         "receiver");
+  TF_ASSERT_OK(s.ToGraph(&g));
+
+  std::unordered_map<string, Node*> index_by_name = g.BuildNodeNameIndex();
+  Node* arg_node = index_by_name.at("arg");
+  arg_node->AddAttr(
+      "_output_shapes",
+      std::vector<TensorShapeProto>{MakeDynamicShapeProto977x16()});
+
+  PartialTensorShape partial_shape({977, 16});
+  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  shape_map[arg_node->name()].push_back(partial_shape);
+
+  ConstantFoldingOptions opts;
+  opts.shape_map = &shape_map;
+  bool was_mutated = false;
+  TF_ASSERT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+
+  index_by_name = g.BuildNodeNameIndex();
+  Node* send_node = index_by_name.at("send");
+  const Edge* send_input = nullptr;
+  TF_ASSERT_OK(send_node->input_edge(0, &send_input));
+  Node* folded = send_input->src();
+  ExpectNodeEqual<int32>(folded, {977, 16}, {2});
+
+  // This test checks the stable contract we care about: after folding
+  // Shape(arg), the replacement Const still carries symbolic contents.
+  string serialized_contents_proto;
+  TF_ASSERT_OK(GetNodeAttr(folded->attrs(),
+                           "_user_inferred_value_contents",
+                           &serialized_contents_proto));
+  TensorShapeProto contents_proto;
+  ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto));
+  ASSERT_EQ(contents_proto.expressions_size(), 2);
+  EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0)));
+  EXPECT_FALSE(IsDynamicDimExpr(contents_proto.expressions(1)));
+  EXPECT_EQ(contents_proto.dim(0).size(), 977);
+  EXPECT_EQ(contents_proto.dim(1).size(), 16);
+}
+
+TEST_F(ConstantFoldingTest, FoldSliceOfDynamicShapePreservesContents) {
+  Graph g(OpRegistry::Global());
+  Scope s = Scope::NewRootScope();
+  auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0);
+  auto shape = ops::Shape(s.WithOpName("shape"), arg);
+  auto begin = ops::Const(s.WithOpName("begin"), {0}, {1});
+  auto size = ops::Const(s.WithOpName("size"), {1}, {1});
+  auto slice = ops::Slice(s.WithOpName("slice"), shape, begin, size);
+  auto send = ops::_Send(s.WithOpName("send"), slice, "send", "sender", 0,
+                         "receiver");
+  TF_ASSERT_OK(s.ToGraph(&g));
+
+  std::unordered_map<string, Node*> index_by_name = g.BuildNodeNameIndex();
+  Node* arg_node = index_by_name.at("arg");
+  arg_node->AddAttr(
+      "_output_shapes",
+      std::vector<TensorShapeProto>{MakeDynamicShapeProto977x16()});
+
+  PartialTensorShape partial_shape({977, 16});
+  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  shape_map[arg_node->name()].push_back(partial_shape);
+
+  ConstantFoldingOptions opts;
+  opts.shape_map = &shape_map;
+  bool was_mutated = false;
+  TF_ASSERT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+
+  index_by_name = g.BuildNodeNameIndex();
+  Node* send_node = index_by_name.at("send");
+  const Edge* send_input = nullptr;
+  TF_ASSERT_OK(send_node->input_edge(0, &send_input));
+  Node* folded = send_input->src();
+  ExpectNodeEqual<int32>(folded, {977}, {1});
+
+  // Folding Slice(Shape(arg), [0], [1]) should preserve the selected symbolic
+  // content, not just the concrete value 977.
+  string serialized_contents_proto;
+  TF_ASSERT_OK(GetNodeAttr(folded->attrs(),
+                           "_user_inferred_value_contents",
+                           &serialized_contents_proto));
+  TensorShapeProto contents_proto;
+  ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto));
+  ASSERT_EQ(contents_proto.expressions_size(), 1);
+  EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0)));
+  EXPECT_EQ(contents_proto.dim(0).size(), 977);
+}
+
+TEST_F(ConstantFoldingTest, FoldGatherOfDynamicShapePreservesContents) {
+  Graph g(OpRegistry::Global());
+  Scope s = Scope::NewRootScope();
+  auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0);
+  auto shape = ops::Shape(s.WithOpName("shape"), arg);
+  auto index = ops::Const(s.WithOpName("index"), 0);
+  auto gather = ops::GatherV2(s.WithOpName("gather"), shape, index,
+                              ops::Const(s.WithOpName("axis"), 0));
+  auto send = ops::_Send(s.WithOpName("send"), gather, "send", "sender", 0,
+                         "receiver");
+  TF_ASSERT_OK(s.ToGraph(&g));
+
+  std::unordered_map<string, Node*> index_by_name = g.BuildNodeNameIndex();
+  Node* arg_node = index_by_name.at("arg");
+  arg_node->AddAttr(
+      "_output_shapes",
+      std::vector<TensorShapeProto>{MakeDynamicShapeProto977x16()});
+
+  PartialTensorShape partial_shape({977, 16});
+  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  shape_map[arg_node->name()].push_back(partial_shape);
+
+  ConstantFoldingOptions opts;
+  opts.shape_map = &shape_map;
+  bool was_mutated = false;
+  TF_ASSERT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+
+  index_by_name = g.BuildNodeNameIndex();
+  Node* send_node = index_by_name.at("send");
+  const Edge* send_input = nullptr;
+  TF_ASSERT_OK(send_node->input_edge(0, &send_input));
+  Node* folded = send_input->src();
+  ExpectNodeEqual<int32>(folded, {977}, {});
+
+  // Folding Gather(Shape(arg), 0, axis=0) should preserve the selected
+  // symbolic content, not just the concrete scalar value 977.
+  string serialized_contents_proto;
+  TF_ASSERT_OK(GetNodeAttr(folded->attrs(),
+                           "_user_inferred_value_contents",
+                           &serialized_contents_proto));
+  TensorShapeProto contents_proto;
+  ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto));
+  ASSERT_EQ(contents_proto.expressions_size(), 1);
+  EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0)));
+  EXPECT_EQ(contents_proto.dim(0).size(), 977);
+}
+
+TEST_F(ConstantFoldingTest, FoldReshapeOfDynamicShapePreservesContents) {
+  Graph g(OpRegistry::Global());
+  Scope s = Scope::NewRootScope();
+  auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0);
+  auto shape = ops::Shape(s.WithOpName("shape"), arg);
+  auto begin = ops::Const(s.WithOpName("begin"), {0}, {1});
+  auto size = ops::Const(s.WithOpName("size"), {1}, {1});
+  auto slice = ops::Slice(s.WithOpName("slice"), shape, begin, size);
+  auto scalar_shape = ops::Const<int32>(s.WithOpName("scalar_shape"), {}, {0});
+  auto reshape = ops::Reshape(s.WithOpName("reshape"), slice, scalar_shape);
+  auto send = ops::_Send(s.WithOpName("send"), reshape, "send", "sender", 0,
+                         "receiver");
+  TF_ASSERT_OK(s.ToGraph(&g));
+
+  std::unordered_map<string, Node*> index_by_name = g.BuildNodeNameIndex();
+  Node* arg_node = index_by_name.at("arg");
+  arg_node->AddAttr(
+      "_output_shapes",
+      std::vector<TensorShapeProto>{MakeDynamicShapeProto977x16()});
+
+  PartialTensorShape partial_shape({977, 16});
+  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  shape_map[arg_node->name()].push_back(partial_shape);
+
+  ConstantFoldingOptions opts;
+  opts.shape_map = &shape_map;
+  bool was_mutated = false;
+  TF_ASSERT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+
+  index_by_name = g.BuildNodeNameIndex();
+  Node* send_node = index_by_name.at("send");
+  const Edge* send_input = nullptr;
+  TF_ASSERT_OK(send_node->input_edge(0, &send_input));
+  Node* folded = send_input->src();
+  ExpectNodeEqual<int32>(folded, {977}, {});
+
+  // Folding Reshape(Slice(Shape(arg), [0], [1]), []) should preserve the
+  // selected symbolic content when the shape-vector result is scalarized.
+  string serialized_contents_proto;
+  TF_ASSERT_OK(GetNodeAttr(folded->attrs(),
+                           "_user_inferred_value_contents",
+                           &serialized_contents_proto));
+  TensorShapeProto contents_proto;
+  ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto));
+  ASSERT_EQ(contents_proto.expressions_size(), 1);
+  EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0)));
+  EXPECT_EQ(contents_proto.dim(0).size(), 977);
+}
+
 TEST_F(ConstantFoldingTest, NoReplacePartialOutput) {
   Graph g(OpRegistry::Global());
   {