From e0d8b3c7466e1a2fb6a6e456da77a6cc0473b541 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 11:48:24 +0100 Subject: [PATCH 1/8] Preserve symbolic contents during partial constant folding --- tensorflow/compiler/jit/kernels/xla_ops.cc | 30 +- .../compiler/tf2xla/kernels/const_op.cc | 55 ++- .../core/common_runtime/constant_folding.cc | 402 +++++++++++++++++- 3 files changed, 442 insertions(+), 45 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 2fb97d66b3ac63..c8b0d2eee913d6 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -520,30 +520,32 @@ absl::Status CompileToLocalExecutable( return; } + auto inferred_shape_it = attr_map.find("user_inferred_shape"); + auto inferred_contents_it = + attr_map.find("user_inferred_value_contents"); bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); - if (has_dynamic_it == attr_map.end()) { - return; - } - has_dynamic = has_dynamic_it->second.b(); - if (!has_dynamic) { - return; + if (has_dynamic_it != attr_map.end()) { + has_dynamic = has_dynamic_it->second.b(); } - auto inferred_shape_it = attr_map.find("user_inferred_shape"); - if (inferred_shape_it == attr_map.end()) { - VLOG(1) << "XlaCompileOp saw has_dynamic for const arg " - << arg_index << " node=" << node_name - << " but no user_inferred_shape attr"; + if (inferred_contents_it == attr_map.end() && + (!has_dynamic || inferred_shape_it == attr_map.end())) { return; } TensorShapeProto inferred_shape_proto; - inferred_shape_proto = inferred_shape_it->second.shape(); + if (inferred_contents_it != attr_map.end()) { + inferred_shape_proto = inferred_contents_it->second.shape(); + } else { + inferred_shape_proto = inferred_shape_it->second.shape(); + } TensorShape inferred_shape(inferred_shape_proto); - if (!TensorShapeUtils::IsVector(arg.constant_value.shape()) || - arg.constant_value.NumElements() != inferred_shape.dims()) { + if (!((TensorShapeUtils::IsVector(arg.constant_value.shape()) && + arg.constant_value.NumElements() == inferred_shape.dims()) || + (TensorShapeUtils::IsScalar(arg.constant_value.shape()) && + inferred_shape.dims() == 1))) { VLOG(1) << "XlaCompileOp const arg " << arg_index << " node=" << node_name << " has dynamic shape metadata but tensor shape " diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index a911f28246da77..cc3b3cff365895 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -200,6 +200,13 @@ int64_t CountDynamicShapeContents(const TensorShapeProto& shape) { return dynamic_count; } +bool CanAttachContentsFromTensorShapeProto(const TensorShape& tensor_shape, + const TensorShapeProto& contents) { + return (tensor_shape.dims() == 0 && contents.dim_size() == 1) || + (tensor_shape.dims() == 1 && + tensor_shape.dim_size(0) == contents.dim_size()); +} + class ConstOp : public XlaOpKernel { public: explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -219,17 +226,23 @@ class ConstOp : public XlaOpKernel { bool has_dynamic = false; TensorShapeProto inferred_shape_proto; + TensorShapeProto inferred_value_contents_proto; if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() && - has_dynamic) { - if (GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", - &inferred_shape_proto) - .ok()) { - VLOG(1) << "ConstOp recovered dynamic folded-const metadata with " - << "inferred_shape=" << inferred_shape_proto.DebugString() - << " dynamic_exprs=" - << CountDynamicShapeContents(inferred_shape_proto); - } + has_dynamic && + GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", + &inferred_shape_proto) + .ok()) { + VLOG(1) << "ConstOp recovered dynamic folded-const metadata with " + << "inferred_shape=" << inferred_shape_proto.DebugString() + << " dynamic_exprs=" + << CountDynamicShapeContents(inferred_shape_proto); } + GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", + &inferred_value_contents_proto) + .IgnoreError(); + const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; + const TensorShapeProto& contents_proto = + has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto; // To avoid blowups for large constants filled with the same value, // recognize that case and emit a scalar broadcast instead. @@ -246,14 +259,14 @@ class ConstOp : public XlaOpKernel { xla::Broadcast(value, shape.dim_sizes(), shape.get_expressions()); XlaExpression output = XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0)); - if (has_dynamic && shape.dims() == 1 && - shape.dim_size(0) == inferred_shape_proto.dim_size()) { + if ((has_contents_proto || has_dynamic) && + CanAttachContentsFromTensorShapeProto(shape, contents_proto)) { VLOG(1) << "ConstOp attaching shape contents through broadcast fast " - << "path with " << shape.dim_size(0) + << "path with " << shape.num_elements() << " entries and dynamic_exprs=" - << CountDynamicShapeContents(inferred_shape_proto); + << CountDynamicShapeContents(contents_proto); output.set_contents( - BuildShapeContentsFromTensorShapeProto(inferred_shape_proto)); + BuildShapeContentsFromTensorShapeProto(contents_proto)); } ctx->SetOutputExpression(0, output); return; @@ -264,19 +277,19 @@ class ConstOp : public XlaOpKernel { OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), errors::InvalidArgument("Cannot parse tensor from proto: ", proto_.DebugString())); - if (has_dynamic) { + if (has_contents_proto || has_dynamic) { VLOG(1) << "ConstOp tensor path tensor_shape=" << tensor.shape().DebugString() << " inferred_rank=" - << inferred_shape_proto.dim_size(); + << contents_proto.dim_size(); } XlaExpression output = XlaExpression::Constant(tensor); - if (has_dynamic && tensor.dims() == 1 && - tensor.dim_size(0) == inferred_shape_proto.dim_size()) { + if ((has_contents_proto || has_dynamic) && + CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) { VLOG(1) << "ConstOp attaching shape contents to folded const with " - << tensor.dim_size(0) << " entries and dynamic_exprs=" - << CountDynamicShapeContents(inferred_shape_proto); + << tensor.NumElements() << " entries and dynamic_exprs=" + << CountDynamicShapeContents(contents_proto); output.set_contents( - BuildShapeContentsFromTensorShapeProto(inferred_shape_proto)); + BuildShapeContentsFromTensorShapeProto(contents_proto)); } ctx->SetOutputExpression(0, output); } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index e4427877c33cce..681ea1a280b77a 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -53,6 +53,7 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; +const char kUserInferredValueContentsAttrName[] = "user_inferred_value_contents"; bool IsShapeOp(const Node* n); @@ -80,6 +81,363 @@ bool GetShapeFromDirectDynamicSource(const Node* node, GetShapeFromArgNode(node, out_shape); } +bool GetConstTensor(const Node* node, Tensor* tensor) { + if (node == nullptr || !node->IsConstant()) { + return false; + } + const TensorProto* tensor_proto; + if (!GetNodeAttr(node->attrs(), "value", &tensor_proto).ok()) { + return false; + } + DataType dtype; + if (!GetNodeAttr(node->attrs(), "dtype", &dtype).ok()) { + return false; + } + *tensor = Tensor(dtype); + return tensor->FromProto(cpu_allocator(), *tensor_proto); +} + +bool GetInputConstTensor(const Node* node, int input_index, Tensor* tensor) { + const Edge* edge; + if (!node->input_edge(input_index, &edge).ok()) { + return false; + } + return GetConstTensor(edge->src(), tensor); +} + +bool GetTensorIntValues(const Tensor& tensor, std::vector* values) { + values->clear(); + if (tensor.dims() == 0) { + values->reserve(1); + switch (tensor.dtype()) { + case DT_INT32: + values->push_back(tensor.scalar()()); + return true; + case DT_INT64: + values->push_back(tensor.scalar()()); + return true; + default: + return false; + } + } + if (tensor.dims() != 1) { + return false; + } + values->reserve(tensor.NumElements()); + switch (tensor.dtype()) { + case DT_INT32: { + auto flat = tensor.flat(); + for (int i = 0; i < flat.size(); ++i) values->push_back(flat(i)); + return true; + } + case DT_INT64: { + auto flat = tensor.flat(); + for (int i = 0; i < flat.size(); ++i) values->push_back(flat(i)); + return true; + } + default: + return false; + } +} + +void CopyContentAt(const TensorShapeProto& input_contents, int64_t index, + TensorShapeProto* output_contents) { + output_contents->add_dim()->CopyFrom(input_contents.dim(index)); + if (index < input_contents.expressions_size()) { + output_contents->add_expressions()->CopyFrom(input_contents.expressions(index)); + } +} + +void AppendScalarConstantContent(int64_t value, TensorShapeProto* output_contents) { + output_contents->add_dim()->set_size(value); +} + +void AppendScalarContentFromTensor(const Tensor& tensor, + TensorShapeProto* output_contents) { + if (tensor.dtype() == DT_INT32) { + AppendScalarConstantContent(tensor.scalar()(), output_contents); + } else { + AppendScalarConstantContent(tensor.scalar()(), output_contents); + } +} + +ExpressionProto MakeConstantExpressionProto(int64_t value) { + ExpressionProto expr; + expr.set_constant_value(value); + return expr; +} + +ExpressionProto GetContentExpressionProto(const TensorShapeProto& contents, + int64_t index) { + if (index < contents.expressions_size()) { + return contents.expressions(index); + } + return MakeConstantExpressionProto(contents.dim(index).size()); +} + +ExpressionProto MakeMulExpressionProto(ExpressionProto lhs, ExpressionProto rhs) { + ExpressionProto expr; + auto* mul = expr.mutable_mul_node(); + *mul->mutable_lhs() = std::move(lhs); + *mul->mutable_rhs() = std::move(rhs); + return expr; +} + +bool TryGetFoldedValueContents(const Node* node, int output_index, + TensorShapeProto* out_contents) { + out_contents->Clear(); + if (output_index != 0) { + return false; + } + + TensorShapeProto existing_contents; + if (GetNodeAttr(node->attrs(), kUserInferredValueContentsAttrName, + &existing_contents) + .ok()) { + *out_contents = existing_contents; + return true; + } + + bool has_dynamic = false; + TensorShapeProto user_inferred_shape; + if (GetNodeAttr(node->attrs(), "has_dynamic", &has_dynamic).ok() && + has_dynamic && + GetNodeAttr(node->attrs(), "user_inferred_shape", &user_inferred_shape) + .ok()) { + *out_contents = user_inferred_shape; + return true; + } + + if (GetShapeFromDirectDynamicSource(node, out_contents)) { + return true; + } + + auto recurse_input = [&](int input_index, + TensorShapeProto* input_contents) -> bool { + const Edge* input_edge; + if (!node->input_edge(input_index, &input_edge).ok()) { + return false; + } + return TryGetFoldedValueContents(input_edge->src(), input_edge->src_output(), + input_contents); + }; + + if (node->IsIdentity() || node->type_string() == "Cast") { + return recurse_input(0, out_contents); + } + + TensorShapeProto input_contents; + if (node->type_string() == "Reshape" && recurse_input(0, &input_contents)) { + Tensor shape_tensor; + std::vector shape_dims; + if (!GetInputConstTensor(node, 1, &shape_tensor) || + !GetTensorIntValues(shape_tensor, &shape_dims)) { + return false; + } + if (input_contents.dim_size() == 1 && shape_dims.empty()) { + CopyContentAt(input_contents, 0, out_contents); + return true; + } + if (shape_dims.size() == 1 && input_contents.dim_size() == shape_dims[0]) { + out_contents->CopyFrom(input_contents); + return true; + } + return false; + } + + if (node->type_string() == "Pack") { + for (int i = 0; i < node->num_inputs(); ++i) { + TensorShapeProto scalar_contents; + Tensor scalar_tensor; + if (recurse_input(i, &scalar_contents)) { + if (scalar_contents.dim_size() != 1) { + return false; + } + CopyContentAt(scalar_contents, 0, out_contents); + } else if (GetInputConstTensor(node, i, &scalar_tensor) && + scalar_tensor.dims() == 0 && + (scalar_tensor.dtype() == DT_INT32 || + scalar_tensor.dtype() == DT_INT64)) { + AppendScalarContentFromTensor(scalar_tensor, out_contents); + } else { + return false; + } + } + return true; + } + + if (node->type_string() == "ConcatV2") { + Tensor axis_tensor; + std::vector axis_values; + if (!GetInputConstTensor(node, node->num_inputs() - 1, &axis_tensor) || + !GetTensorIntValues(axis_tensor, &axis_values) || axis_values.size() != 1) { + return false; + } + int64_t axis = axis_values[0]; + if (axis != 0 && axis != -1) { + return false; + } + for (int i = 0; i < node->num_inputs() - 1; ++i) { + TensorShapeProto part_contents; + if (!recurse_input(i, &part_contents)) { + return false; + } + for (int64_t j = 0; j < part_contents.dim_size(); ++j) { + CopyContentAt(part_contents, j, out_contents); + } + } + return true; + } + + if ((node->type_string() == "Gather" || node->type_string() == "GatherV2") && + recurse_input(0, &input_contents)) { + Tensor indices_tensor; + std::vector indices; + if (!GetInputConstTensor(node, 1, &indices_tensor) || + !GetTensorIntValues(indices_tensor, &indices)) { + return false; + } + + int64_t axis = 0; + if (node->type_string() == "GatherV2") { + Tensor axis_tensor; + std::vector axis_values; + if (!GetInputConstTensor(node, 2, &axis_tensor) || + !GetTensorIntValues(axis_tensor, &axis_values) || + axis_values.size() != 1) { + return false; + } + axis = axis_values[0]; + } + + const int64_t params_rank = 1; + if (axis < 0) axis += params_rank; + if (axis != 0) { + return false; + } + + const int64_t rank = input_contents.dim_size(); + for (int64_t index : indices) { + if (index < 0) index += rank; + if (index < 0 || index >= rank) { + return false; + } + CopyContentAt(input_contents, index, out_contents); + } + return true; + } + + if (node->type_string() == "Prod" && recurse_input(0, &input_contents)) { + Tensor reduction_indices_tensor; + std::vector axes; + bool keep_dims = false; + if (!GetInputConstTensor(node, 1, &reduction_indices_tensor) || + !GetTensorIntValues(reduction_indices_tensor, &axes) || + !GetNodeAttr(node->attrs(), "keep_dims", &keep_dims).ok() || + keep_dims || axes.size() != 1 || + (axes[0] != 0 && axes[0] != -1) || input_contents.dim_size() == 0) { + return false; + } + int64_t value = 1; + ExpressionProto expr = GetContentExpressionProto(input_contents, 0); + for (int64_t i = 0; i < input_contents.dim_size(); ++i) { + value *= input_contents.dim(i).size(); + if (i > 0) { + expr = MakeMulExpressionProto(std::move(expr), + GetContentExpressionProto(input_contents, i)); + } + } + out_contents->add_dim()->set_size(value); + out_contents->add_expressions()->Swap(&expr); + return true; + } + + if (node->type_string() == "Slice" && recurse_input(0, &input_contents)) { + Tensor begin_tensor; + Tensor size_tensor; + std::vector begin; + std::vector size; + if (!GetInputConstTensor(node, 1, &begin_tensor) || + !GetInputConstTensor(node, 2, &size_tensor) || + !GetTensorIntValues(begin_tensor, &begin) || + !GetTensorIntValues(size_tensor, &size) || begin.size() != 1 || + size.size() != 1) { + return false; + } + int64_t start = begin[0]; + if (start < 0 || start > input_contents.dim_size()) { + return false; + } + int64_t length = size[0] < 0 ? input_contents.dim_size() - start : size[0]; + if (length < 0 || start + length > input_contents.dim_size()) { + return false; + } + for (int64_t i = 0; i < length; ++i) { + CopyContentAt(input_contents, start + i, out_contents); + } + return true; + } + + if (node->type_string() == "StridedSlice" && + recurse_input(0, &input_contents)) { + Tensor begin_tensor; + Tensor end_tensor; + Tensor strides_tensor; + std::vector begin; + std::vector end; + std::vector strides; + int64_t begin_mask = 0; + int64_t end_mask = 0; + int64_t ellipsis_mask = 0; + int64_t new_axis_mask = 0; + int64_t shrink_axis_mask = 0; + if (!GetInputConstTensor(node, 1, &begin_tensor) || + !GetInputConstTensor(node, 2, &end_tensor) || + !GetInputConstTensor(node, 3, &strides_tensor) || + !GetTensorIntValues(begin_tensor, &begin) || + !GetTensorIntValues(end_tensor, &end) || + !GetTensorIntValues(strides_tensor, &strides) || begin.size() != 1 || + end.size() != 1 || strides.size() != 1 || + !GetNodeAttr(node->attrs(), "begin_mask", &begin_mask).ok() || + !GetNodeAttr(node->attrs(), "end_mask", &end_mask).ok() || + !GetNodeAttr(node->attrs(), "ellipsis_mask", &ellipsis_mask).ok() || + !GetNodeAttr(node->attrs(), "new_axis_mask", &new_axis_mask).ok() || + !GetNodeAttr(node->attrs(), "shrink_axis_mask", &shrink_axis_mask).ok()) { + return false; + } + if (ellipsis_mask != 0 || new_axis_mask != 0) { + return false; + } + const int64_t rank = input_contents.dim_size(); + int64_t stride = strides[0]; + if (stride == 0) { + return false; + } + int64_t start = (begin_mask & 1) ? (stride > 0 ? 0 : rank - 1) : begin[0]; + int64_t stop = (end_mask & 1) ? (stride > 0 ? rank : -1) : end[0]; + if (start < 0) start += rank; + if (stop < 0 && !(end_mask & 1 && stride < 0)) stop += rank; + if (shrink_axis_mask & 1) { + if (start < 0 || start >= rank) { + return false; + } + CopyContentAt(input_contents, start, out_contents); + return true; + } + if (stride < 0) { + return false; + } + start = std::max(0, start); + stop = std::min(rank, stop); + for (int64_t i = start; i < stop; i += stride) { + CopyContentAt(input_contents, i, out_contents); + } + return true; + } + + return false; +} + // For stateless RNGs ops, they are pure but device-dependent. Those ops are not // constant-foldable. static absl::flat_hash_set* kBlockList = @@ -273,16 +631,13 @@ bool IsConstantFoldable( const std::function& consider, int64_t max_constant_size_in_bytes, std::unordered_map>* shape_replacement_map) { - TensorShapeProto dynamic_shape; - if (GetShapeFromDirectDynamicSource(n, &dynamic_shape)) { - VLOG(1) << "Skipping constant folding for dynamic shape-derived node " - << n->name() << " op=" << n->type_string() - << " inferred_shape=" << dynamic_shape.DebugString(); - return false; - } - if (n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { - VLOG(1) << "Skipping constant folding for shape-derived node " - << n->name() << " op=" << n->type_string(); + TensorShapeProto exact_contents; + const bool has_exact_contents = TryGetFoldedValueContents(n, 0, &exact_contents); + const bool has_dynamic = + GetShapeFromDirectDynamicSource(n, &exact_contents); + const bool is_shape_derived = + n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr; + if ((has_dynamic || is_shape_derived) && (!has_exact_contents || n->num_outputs() > 1)) { return false; } if (n->IsConstant()) { @@ -492,6 +847,8 @@ void AddShapeNodeToConstantGraph( TensorShapeProto user_inferred_shape; const bool has_dynamic = GetShapeFromDirectDynamicSource(n, &user_inferred_shape); + TensorShapeProto exact_contents; + const bool has_exact_contents = TryGetFoldedValueContents(n, 0, &exact_contents); std::vector& added = (*node_map)[n]; const string& node_name = n->name(); for (const Tensor& t : shape_replacement_map.at(n)) { @@ -505,6 +862,9 @@ void AddShapeNodeToConstantGraph( builder.Attr("has_dynamic", has_dynamic) .Attr("user_inferred_shape", user_inferred_shape); } + if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { + builder.Attr(kUserInferredValueContentsAttrName, exact_contents); + } NodeDef def; CHECK(builder.Finalize(&def).ok()); Node* constant_node; @@ -626,6 +986,24 @@ bool ReplaceTensorWithConstant( TensorShapeProto user_inferred_shape; const bool has_dynamic = GetShapeFromDirectDynamicSource(tensor.first, &user_inferred_shape); + const bool is_shape_derived = + tensor.first->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr; + if (tensor.second != 0 && (has_dynamic || is_shape_derived)) { + VLOG(1) << "Skipping replacement of " << tensor.first->name() << " :: " + << tensor.second + << " because symbolic content preservation is only supported for " + << "single-output replacements"; + return false; + } + TensorShapeProto exact_contents; + const bool has_exact_contents = + TryGetFoldedValueContents(tensor.first, tensor.second, &exact_contents); + if ((has_dynamic || is_shape_derived) && !has_exact_contents) { + VLOG(1) << "Skipping replacement of " << tensor.first->name() << " :: " + << tensor.second + << " because constant folding could not preserve symbolic contents"; + return false; + } Node* constant_node; auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const") .Attr("dtype", constant.dtype()) @@ -634,6 +1012,10 @@ bool ReplaceTensorWithConstant( builder.Attr("has_dynamic", has_dynamic) .Attr("user_inferred_shape", user_inferred_shape); } + if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { + builder.Attr("has_dynamic", true) + .Attr(kUserInferredValueContentsAttrName, exact_contents); + } if (partition_device) { builder.Device(partition_device->name()); } From 56b8330ef5a6a2de0109dc077aceda54efad91c3 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 11:48:39 +0100 Subject: [PATCH 2/8] Serialize symbolic const contents attrs --- tensorflow/compiler/jit/kernels/xla_ops.cc | 20 +++++------- .../compiler/tf2xla/kernels/const_op.cc | 31 +++++++------------ .../core/common_runtime/constant_folding.cc | 16 ++++++++-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index c8b0d2eee913d6..3c3e5842b153c1 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -523,6 +523,8 @@ absl::Status CompileToLocalExecutable( auto inferred_shape_it = attr_map.find("user_inferred_shape"); auto inferred_contents_it = attr_map.find("user_inferred_value_contents"); + auto inferred_contents_serialized_it = + attr_map.find("user_inferred_value_contents_serialized"); bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); if (has_dynamic_it != attr_map.end()) { @@ -530,12 +532,16 @@ absl::Status CompileToLocalExecutable( } if (inferred_contents_it == attr_map.end() && + inferred_contents_serialized_it == attr_map.end() && (!has_dynamic || inferred_shape_it == attr_map.end())) { return; } TensorShapeProto inferred_shape_proto; - if (inferred_contents_it != attr_map.end()) { + if (inferred_contents_serialized_it != attr_map.end()) { + inferred_shape_proto.ParseFromString( + inferred_contents_serialized_it->second.s()); + } else if (inferred_contents_it != attr_map.end()) { inferred_shape_proto = inferred_contents_it->second.shape(); } else { inferred_shape_proto = inferred_shape_it->second.shape(); @@ -546,11 +552,6 @@ absl::Status CompileToLocalExecutable( arg.constant_value.NumElements() == inferred_shape.dims()) || (TensorShapeUtils::IsScalar(arg.constant_value.shape()) && inferred_shape.dims() == 1))) { - VLOG(1) << "XlaCompileOp const arg " << arg_index - << " node=" << node_name - << " has dynamic shape metadata but tensor shape " - << arg.constant_value.shape().DebugString() - << " does not match inferred rank " << inferred_shape.dims(); return; } @@ -566,18 +567,11 @@ absl::Status CompileToLocalExecutable( } else if (arg.constant_value.dtype() == DT_INT64) { expr.set_constant_value(arg.constant_value.flat()(i)); } else { - VLOG(1) << "XlaCompileOp const arg " << arg_index - << " node=" << node_name - << " has unsupported dtype for inferred shape contents: " - << DataTypeString(arg.constant_value.dtype()); arg.constant_value_expressions.clear(); return; } arg.constant_value_expressions.push_back(std::move(expr)); } - VLOG(1) << "XlaCompileOp recovered " << arg.constant_value_expressions.size() - << " constant_value_expressions for const arg " << arg_index - << " node=" << node_name << " from user_inferred_shape"; }; auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) { if (!saw_dynamic_dim_value) { diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index cc3b3cff365895..1c1c3874493b76 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -227,19 +227,24 @@ class ConstOp : public XlaOpKernel { bool has_dynamic = false; TensorShapeProto inferred_shape_proto; TensorShapeProto inferred_value_contents_proto; + string inferred_value_contents_serialized; if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() && has_dynamic && GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", &inferred_shape_proto) .ok()) { - VLOG(1) << "ConstOp recovered dynamic folded-const metadata with " - << "inferred_shape=" << inferred_shape_proto.DebugString() - << " dynamic_exprs=" - << CountDynamicShapeContents(inferred_shape_proto); } - GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", - &inferred_value_contents_proto) - .IgnoreError(); + if (GetNodeAttr(ctx->op_kernel().def(), + "user_inferred_value_contents_serialized", + &inferred_value_contents_serialized) + .ok()) { + inferred_value_contents_proto.ParseFromString( + inferred_value_contents_serialized); + } else { + GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", + &inferred_value_contents_proto) + .IgnoreError(); + } const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; const TensorShapeProto& contents_proto = has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto; @@ -261,10 +266,6 @@ class ConstOp : public XlaOpKernel { XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0)); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(shape, contents_proto)) { - VLOG(1) << "ConstOp attaching shape contents through broadcast fast " - << "path with " << shape.num_elements() - << " entries and dynamic_exprs=" - << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } @@ -277,17 +278,9 @@ class ConstOp : public XlaOpKernel { OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), errors::InvalidArgument("Cannot parse tensor from proto: ", proto_.DebugString())); - if (has_contents_proto || has_dynamic) { - VLOG(1) << "ConstOp tensor path tensor_shape=" - << tensor.shape().DebugString() << " inferred_rank=" - << contents_proto.dim_size(); - } XlaExpression output = XlaExpression::Constant(tensor); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) { - VLOG(1) << "ConstOp attaching shape contents to folded const with " - << tensor.NumElements() << " entries and dynamic_exprs=" - << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 681ea1a280b77a..9a5bf9e723ea41 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -54,6 +54,8 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; const char kUserInferredValueContentsAttrName[] = "user_inferred_value_contents"; +const char kUserInferredValueContentsSerializedAttrName[] = + "user_inferred_value_contents_serialized"; bool IsShapeOp(const Node* n); @@ -190,6 +192,14 @@ bool TryGetFoldedValueContents(const Node* node, int output_index, return false; } + string serialized_contents; + if (GetNodeAttr(node->attrs(), kUserInferredValueContentsSerializedAttrName, + &serialized_contents) + .ok() && + out_contents->ParseFromString(serialized_contents)) { + return true; + } + TensorShapeProto existing_contents; if (GetNodeAttr(node->attrs(), kUserInferredValueContentsAttrName, &existing_contents) @@ -863,7 +873,8 @@ void AddShapeNodeToConstantGraph( .Attr("user_inferred_shape", user_inferred_shape); } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { - builder.Attr(kUserInferredValueContentsAttrName, exact_contents); + builder.Attr(kUserInferredValueContentsSerializedAttrName, + exact_contents.SerializeAsString()); } NodeDef def; CHECK(builder.Finalize(&def).ok()); @@ -1014,7 +1025,8 @@ bool ReplaceTensorWithConstant( } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { builder.Attr("has_dynamic", true) - .Attr(kUserInferredValueContentsAttrName, exact_contents); + .Attr(kUserInferredValueContentsSerializedAttrName, + exact_contents.SerializeAsString()); } if (partition_device) { builder.Device(partition_device->name()); From f254e39966196ea2b93941bd487bcded5338b8a0 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Fri, 24 Apr 2026 19:06:20 +0100 Subject: [PATCH 3/8] Avoid partial non-CPU multi-output constant replacement --- tensorflow/core/common_runtime/constant_folding.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 9a5bf9e723ea41..9cee31219b3c81 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -966,6 +966,12 @@ bool ReplaceTensorWithConstant( ? DeviceType{partition_device->device_type()} : DEVICE_CPU; if (partition_device && device_type != DEVICE_CPU) { + // Constant folding replaces one output edge-set at a time. Be + // conservative for non-CPU multi-output ops, since partially replacing a + // node can violate per-output placement or memory-type assumptions. + if (tensor.first->num_outputs() > 1) { + return false; + } MemoryTypeVector input_mvec; MemoryTypeVector output_mvec; if (!MemoryTypesForNode(graph->op_registry(), device_type, From 3cda43856a57fa1c1a2dff932530ac999457849b Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 11:38:23 +0100 Subject: [PATCH 4/8] Add symbolic constant folding coverage --- tensorflow/core/common_runtime/BUILD | 4 + .../common_runtime/constant_folding_test.cc | 201 ++++++++++++++++++ 2 files changed, 205 insertions(+) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 301015eba61fe8..24089b11472b6b 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -2754,6 +2754,7 @@ tf_cc_test( ":direct_session_internal", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -2769,9 +2770,12 @@ tf_cc_test( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:concat_op", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:gather_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:immutable_constant_op", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:reshape_op", + "//tensorflow/core/kernels:slice_op", "//tensorflow/core/kernels:topk_op", "@eigen_archive//:eigen3", ], diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 481a85add4893c..84b5f324f9a62c 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/cc/ops/array_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/node_builder.h" @@ -45,6 +47,15 @@ limitations under the License. namespace tensorflow { namespace { +TensorShapeProto MakeDynamicShapeProto977x16() { + TensorShapeProto proto; + proto.add_dim()->set_size(977); + proto.add_dim()->set_size(16); + proto.add_expressions()->set_variable_id(0); + proto.add_expressions()->set_constant_value(16); + return proto; +} + class ConstantFoldingTest : public ::testing::Test { protected: template @@ -634,6 +645,196 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) { } } +TEST_F(ConstantFoldingTest, FoldShapeFromDynamicArgPreservesContents) { + Graph g(OpRegistry::Global()); + Scope s = Scope::NewRootScope(); + auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0); + auto shape = ops::Shape(s.WithOpName("shape"), arg); + auto send = ops::_Send(s.WithOpName("send"), shape, "send", "sender", 0, + "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + + std::unordered_map index_by_name = g.BuildNodeNameIndex(); + Node* arg_node = index_by_name.at("arg"); + arg_node->AddAttr("_output_shapes", + std::vector{MakeDynamicShapeProto977x16()}); + + PartialTensorShape partial_shape({977, 16}); + std::unordered_map> shape_map; + shape_map[arg_node->name()].push_back(partial_shape); + + ConstantFoldingOptions opts; + opts.shape_map = &shape_map; + bool was_mutated = false; + TF_ASSERT_OK( + ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + + index_by_name = g.BuildNodeNameIndex(); + Node* send_node = index_by_name.at("send"); + const Edge* send_input = nullptr; + TF_ASSERT_OK(send_node->input_edge(0, &send_input)); + Node* folded = send_input->src(); + ExpectNodeEqual(folded, {977, 16}, {2}); + + // This test checks the stable contract we care about: after folding + // Shape(arg), the replacement Const still carries symbolic contents. + string serialized_contents_proto; + TF_ASSERT_OK(GetNodeAttr(folded->attrs(), + "user_inferred_value_contents_serialized", + &serialized_contents_proto)); + TensorShapeProto contents_proto; + ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); + ASSERT_EQ(contents_proto.expressions_size(), 2); + EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0))); + EXPECT_FALSE(IsDynamicDimExpr(contents_proto.expressions(1))); + EXPECT_EQ(contents_proto.dim(0).size(), 977); + EXPECT_EQ(contents_proto.dim(1).size(), 16); +} + +TEST_F(ConstantFoldingTest, FoldSliceOfDynamicShapePreservesContents) { + Graph g(OpRegistry::Global()); + Scope s = Scope::NewRootScope(); + auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0); + auto shape = ops::Shape(s.WithOpName("shape"), arg); + auto begin = ops::Const(s.WithOpName("begin"), {0}, {1}); + auto size = ops::Const(s.WithOpName("size"), {1}, {1}); + auto slice = ops::Slice(s.WithOpName("slice"), shape, begin, size); + auto send = ops::_Send(s.WithOpName("send"), slice, "send", "sender", 0, + "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + + std::unordered_map index_by_name = g.BuildNodeNameIndex(); + Node* arg_node = index_by_name.at("arg"); + arg_node->AddAttr("_output_shapes", + std::vector{MakeDynamicShapeProto977x16()}); + + PartialTensorShape partial_shape({977, 16}); + std::unordered_map> shape_map; + shape_map[arg_node->name()].push_back(partial_shape); + + ConstantFoldingOptions opts; + opts.shape_map = &shape_map; + bool was_mutated = false; + TF_ASSERT_OK( + ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + + index_by_name = g.BuildNodeNameIndex(); + Node* send_node = index_by_name.at("send"); + const Edge* send_input = nullptr; + TF_ASSERT_OK(send_node->input_edge(0, &send_input)); + Node* folded = send_input->src(); + ExpectNodeEqual(folded, {977}, {1}); + + // Folding Slice(Shape(arg), [0], [1]) should preserve the selected symbolic + // content, not just the concrete value 977. + string serialized_contents_proto; + TF_ASSERT_OK(GetNodeAttr(folded->attrs(), + "user_inferred_value_contents_serialized", + &serialized_contents_proto)); + TensorShapeProto contents_proto; + ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); + ASSERT_EQ(contents_proto.expressions_size(), 1); + EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0))); + EXPECT_EQ(contents_proto.dim(0).size(), 977); +} + +TEST_F(ConstantFoldingTest, FoldGatherOfDynamicShapePreservesContents) { + Graph g(OpRegistry::Global()); + Scope s = Scope::NewRootScope(); + auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0); + auto shape = ops::Shape(s.WithOpName("shape"), arg); + auto index = ops::Const(s.WithOpName("index"), 0); + auto gather = ops::GatherV2(s.WithOpName("gather"), shape, index, + ops::Const(s.WithOpName("axis"), 0)); + auto send = ops::_Send(s.WithOpName("send"), gather, "send", "sender", 0, + "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + + std::unordered_map index_by_name = g.BuildNodeNameIndex(); + Node* arg_node = index_by_name.at("arg"); + arg_node->AddAttr("_output_shapes", + std::vector{MakeDynamicShapeProto977x16()}); + + PartialTensorShape partial_shape({977, 16}); + std::unordered_map> shape_map; + shape_map[arg_node->name()].push_back(partial_shape); + + ConstantFoldingOptions opts; + opts.shape_map = &shape_map; + bool was_mutated = false; + TF_ASSERT_OK( + ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + + index_by_name = g.BuildNodeNameIndex(); + Node* send_node = index_by_name.at("send"); + const Edge* send_input = nullptr; + TF_ASSERT_OK(send_node->input_edge(0, &send_input)); + Node* folded = send_input->src(); + ExpectNodeEqual(folded, {977}, {}); + + // Folding Gather(Shape(arg), 0, axis=0) should preserve the selected + // symbolic content, not just the concrete scalar value 977. + string serialized_contents_proto; + TF_ASSERT_OK(GetNodeAttr(folded->attrs(), + "user_inferred_value_contents_serialized", + &serialized_contents_proto)); + TensorShapeProto contents_proto; + ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); + ASSERT_EQ(contents_proto.expressions_size(), 1); + EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0))); + EXPECT_EQ(contents_proto.dim(0).size(), 977); +} + +TEST_F(ConstantFoldingTest, FoldReshapeOfDynamicShapePreservesContents) { + Graph g(OpRegistry::Global()); + Scope s = Scope::NewRootScope(); + auto arg = ops::_Arg(s.WithOpName("arg"), DT_FLOAT, 0); + auto shape = ops::Shape(s.WithOpName("shape"), arg); + auto begin = ops::Const(s.WithOpName("begin"), {0}, {1}); + auto size = ops::Const(s.WithOpName("size"), {1}, {1}); + auto slice = ops::Slice(s.WithOpName("slice"), shape, begin, size); + auto scalar_shape = ops::Const(s.WithOpName("scalar_shape"), {}); + auto reshape = + ops::Reshape(s.WithOpName("reshape"), slice, scalar_shape); + auto send = ops::_Send(s.WithOpName("send"), reshape, "send", "sender", 0, + "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + + std::unordered_map index_by_name = g.BuildNodeNameIndex(); + Node* arg_node = index_by_name.at("arg"); + arg_node->AddAttr("_output_shapes", + std::vector{MakeDynamicShapeProto977x16()}); + + PartialTensorShape partial_shape({977, 16}); + std::unordered_map> shape_map; + shape_map[arg_node->name()].push_back(partial_shape); + + ConstantFoldingOptions opts; + opts.shape_map = &shape_map; + bool was_mutated = false; + TF_ASSERT_OK( + ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + + index_by_name = g.BuildNodeNameIndex(); + Node* send_node = index_by_name.at("send"); + const Edge* send_input = nullptr; + TF_ASSERT_OK(send_node->input_edge(0, &send_input)); + Node* folded = send_input->src(); + ExpectNodeEqual(folded, {977}, {}); + + // Folding Reshape(Slice(Shape(arg), [0], [1]), []) should preserve the + // selected symbolic content when the shape-vector result is scalarized. + string serialized_contents_proto; + TF_ASSERT_OK(GetNodeAttr(folded->attrs(), + "user_inferred_value_contents_serialized", + &serialized_contents_proto)); + TensorShapeProto contents_proto; + ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); + ASSERT_EQ(contents_proto.expressions_size(), 1); + EXPECT_TRUE(IsDynamicDimExpr(contents_proto.expressions(0))); + EXPECT_EQ(contents_proto.dim(0).size(), 977); +} + TEST_F(ConstantFoldingTest, NoReplacePartialOutput) { Graph g(OpRegistry::Global()); { From ad7ecc23db86e1873def6bcb4100177e91584ebf Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 12:08:05 +0100 Subject: [PATCH 5/8] Revert "Serialize symbolic const contents attrs" This reverts commit 56b8330ef5a6a2de0109dc077aceda54efad91c3. --- tensorflow/compiler/jit/kernels/xla_ops.cc | 20 +++++++----- .../compiler/tf2xla/kernels/const_op.cc | 31 ++++++++++++------- .../core/common_runtime/constant_folding.cc | 16 ++-------- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 3c3e5842b153c1..c8b0d2eee913d6 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -523,8 +523,6 @@ absl::Status CompileToLocalExecutable( auto inferred_shape_it = attr_map.find("user_inferred_shape"); auto inferred_contents_it = attr_map.find("user_inferred_value_contents"); - auto inferred_contents_serialized_it = - attr_map.find("user_inferred_value_contents_serialized"); bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); if (has_dynamic_it != attr_map.end()) { @@ -532,16 +530,12 @@ absl::Status CompileToLocalExecutable( } if (inferred_contents_it == attr_map.end() && - inferred_contents_serialized_it == attr_map.end() && (!has_dynamic || inferred_shape_it == attr_map.end())) { return; } TensorShapeProto inferred_shape_proto; - if (inferred_contents_serialized_it != attr_map.end()) { - inferred_shape_proto.ParseFromString( - inferred_contents_serialized_it->second.s()); - } else if (inferred_contents_it != attr_map.end()) { + if (inferred_contents_it != attr_map.end()) { inferred_shape_proto = inferred_contents_it->second.shape(); } else { inferred_shape_proto = inferred_shape_it->second.shape(); @@ -552,6 +546,11 @@ absl::Status CompileToLocalExecutable( arg.constant_value.NumElements() == inferred_shape.dims()) || (TensorShapeUtils::IsScalar(arg.constant_value.shape()) && inferred_shape.dims() == 1))) { + VLOG(1) << "XlaCompileOp const arg " << arg_index + << " node=" << node_name + << " has dynamic shape metadata but tensor shape " + << arg.constant_value.shape().DebugString() + << " does not match inferred rank " << inferred_shape.dims(); return; } @@ -567,11 +566,18 @@ absl::Status CompileToLocalExecutable( } else if (arg.constant_value.dtype() == DT_INT64) { expr.set_constant_value(arg.constant_value.flat()(i)); } else { + VLOG(1) << "XlaCompileOp const arg " << arg_index + << " node=" << node_name + << " has unsupported dtype for inferred shape contents: " + << DataTypeString(arg.constant_value.dtype()); arg.constant_value_expressions.clear(); return; } arg.constant_value_expressions.push_back(std::move(expr)); } + VLOG(1) << "XlaCompileOp recovered " << arg.constant_value_expressions.size() + << " constant_value_expressions for const arg " << arg_index + << " node=" << node_name << " from user_inferred_shape"; }; auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) { if (!saw_dynamic_dim_value) { diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 1c1c3874493b76..cc3b3cff365895 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -227,24 +227,19 @@ class ConstOp : public XlaOpKernel { bool has_dynamic = false; TensorShapeProto inferred_shape_proto; TensorShapeProto inferred_value_contents_proto; - string inferred_value_contents_serialized; if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() && has_dynamic && GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", &inferred_shape_proto) .ok()) { + VLOG(1) << "ConstOp recovered dynamic folded-const metadata with " + << "inferred_shape=" << inferred_shape_proto.DebugString() + << " dynamic_exprs=" + << CountDynamicShapeContents(inferred_shape_proto); } - if (GetNodeAttr(ctx->op_kernel().def(), - "user_inferred_value_contents_serialized", - &inferred_value_contents_serialized) - .ok()) { - inferred_value_contents_proto.ParseFromString( - inferred_value_contents_serialized); - } else { - GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", - &inferred_value_contents_proto) - .IgnoreError(); - } + GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", + &inferred_value_contents_proto) + .IgnoreError(); const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; const TensorShapeProto& contents_proto = has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto; @@ -266,6 +261,10 @@ class ConstOp : public XlaOpKernel { XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0)); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(shape, contents_proto)) { + VLOG(1) << "ConstOp attaching shape contents through broadcast fast " + << "path with " << shape.num_elements() + << " entries and dynamic_exprs=" + << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } @@ -278,9 +277,17 @@ class ConstOp : public XlaOpKernel { OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), errors::InvalidArgument("Cannot parse tensor from proto: ", proto_.DebugString())); + if (has_contents_proto || has_dynamic) { + VLOG(1) << "ConstOp tensor path tensor_shape=" + << tensor.shape().DebugString() << " inferred_rank=" + << contents_proto.dim_size(); + } XlaExpression output = XlaExpression::Constant(tensor); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) { + VLOG(1) << "ConstOp attaching shape contents to folded const with " + << tensor.NumElements() << " entries and dynamic_exprs=" + << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 9cee31219b3c81..c2c2eed6261af2 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -54,8 +54,6 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; const char kUserInferredValueContentsAttrName[] = "user_inferred_value_contents"; -const char kUserInferredValueContentsSerializedAttrName[] = - "user_inferred_value_contents_serialized"; bool IsShapeOp(const Node* n); @@ -192,14 +190,6 @@ bool TryGetFoldedValueContents(const Node* node, int output_index, return false; } - string serialized_contents; - if (GetNodeAttr(node->attrs(), kUserInferredValueContentsSerializedAttrName, - &serialized_contents) - .ok() && - out_contents->ParseFromString(serialized_contents)) { - return true; - } - TensorShapeProto existing_contents; if (GetNodeAttr(node->attrs(), kUserInferredValueContentsAttrName, &existing_contents) @@ -873,8 +863,7 @@ void AddShapeNodeToConstantGraph( .Attr("user_inferred_shape", user_inferred_shape); } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { - builder.Attr(kUserInferredValueContentsSerializedAttrName, - exact_contents.SerializeAsString()); + builder.Attr(kUserInferredValueContentsAttrName, exact_contents); } NodeDef def; CHECK(builder.Finalize(&def).ok()); @@ -1031,8 +1020,7 @@ bool ReplaceTensorWithConstant( } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { builder.Attr("has_dynamic", true) - .Attr(kUserInferredValueContentsSerializedAttrName, - exact_contents.SerializeAsString()); + .Attr(kUserInferredValueContentsAttrName, exact_contents); } if (partition_device) { builder.Device(partition_device->name()); From 8cf3f2d1f61ce65fcfc079c5282b8f141fdc6439 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 12:26:41 +0100 Subject: [PATCH 6/8] Reapply "Serialize symbolic const contents attrs" This reverts commit ad7ecc23db86e1873def6bcb4100177e91584ebf. --- tensorflow/compiler/jit/kernels/xla_ops.cc | 20 +++++------- .../compiler/tf2xla/kernels/const_op.cc | 31 +++++++------------ .../core/common_runtime/constant_folding.cc | 16 ++++++++-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index c8b0d2eee913d6..3c3e5842b153c1 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -523,6 +523,8 @@ absl::Status CompileToLocalExecutable( auto inferred_shape_it = attr_map.find("user_inferred_shape"); auto inferred_contents_it = attr_map.find("user_inferred_value_contents"); + auto inferred_contents_serialized_it = + attr_map.find("user_inferred_value_contents_serialized"); bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); if (has_dynamic_it != attr_map.end()) { @@ -530,12 +532,16 @@ absl::Status CompileToLocalExecutable( } if (inferred_contents_it == attr_map.end() && + inferred_contents_serialized_it == attr_map.end() && (!has_dynamic || inferred_shape_it == attr_map.end())) { return; } TensorShapeProto inferred_shape_proto; - if (inferred_contents_it != attr_map.end()) { + if (inferred_contents_serialized_it != attr_map.end()) { + inferred_shape_proto.ParseFromString( + inferred_contents_serialized_it->second.s()); + } else if (inferred_contents_it != attr_map.end()) { inferred_shape_proto = inferred_contents_it->second.shape(); } else { inferred_shape_proto = inferred_shape_it->second.shape(); @@ -546,11 +552,6 @@ absl::Status CompileToLocalExecutable( arg.constant_value.NumElements() == inferred_shape.dims()) || (TensorShapeUtils::IsScalar(arg.constant_value.shape()) && inferred_shape.dims() == 1))) { - VLOG(1) << "XlaCompileOp const arg " << arg_index - << " node=" << node_name - << " has dynamic shape metadata but tensor shape " - << arg.constant_value.shape().DebugString() - << " does not match inferred rank " << inferred_shape.dims(); return; } @@ -566,18 +567,11 @@ absl::Status CompileToLocalExecutable( } else if (arg.constant_value.dtype() == DT_INT64) { expr.set_constant_value(arg.constant_value.flat()(i)); } else { - VLOG(1) << "XlaCompileOp const arg " << arg_index - << " node=" << node_name - << " has unsupported dtype for inferred shape contents: " - << DataTypeString(arg.constant_value.dtype()); arg.constant_value_expressions.clear(); return; } arg.constant_value_expressions.push_back(std::move(expr)); } - VLOG(1) << "XlaCompileOp recovered " << arg.constant_value_expressions.size() - << " constant_value_expressions for const arg " << arg_index - << " node=" << node_name << " from user_inferred_shape"; }; auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) { if (!saw_dynamic_dim_value) { diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index cc3b3cff365895..1c1c3874493b76 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -227,19 +227,24 @@ class ConstOp : public XlaOpKernel { bool has_dynamic = false; TensorShapeProto inferred_shape_proto; TensorShapeProto inferred_value_contents_proto; + string inferred_value_contents_serialized; if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() && has_dynamic && GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", &inferred_shape_proto) .ok()) { - VLOG(1) << "ConstOp recovered dynamic folded-const metadata with " - << "inferred_shape=" << inferred_shape_proto.DebugString() - << " dynamic_exprs=" - << CountDynamicShapeContents(inferred_shape_proto); } - GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", - &inferred_value_contents_proto) - .IgnoreError(); + if (GetNodeAttr(ctx->op_kernel().def(), + "user_inferred_value_contents_serialized", + &inferred_value_contents_serialized) + .ok()) { + inferred_value_contents_proto.ParseFromString( + inferred_value_contents_serialized); + } else { + GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", + &inferred_value_contents_proto) + .IgnoreError(); + } const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; const TensorShapeProto& contents_proto = has_contents_proto ? inferred_value_contents_proto : inferred_shape_proto; @@ -261,10 +266,6 @@ class ConstOp : public XlaOpKernel { XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0)); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(shape, contents_proto)) { - VLOG(1) << "ConstOp attaching shape contents through broadcast fast " - << "path with " << shape.num_elements() - << " entries and dynamic_exprs=" - << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } @@ -277,17 +278,9 @@ class ConstOp : public XlaOpKernel { OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), errors::InvalidArgument("Cannot parse tensor from proto: ", proto_.DebugString())); - if (has_contents_proto || has_dynamic) { - VLOG(1) << "ConstOp tensor path tensor_shape=" - << tensor.shape().DebugString() << " inferred_rank=" - << contents_proto.dim_size(); - } XlaExpression output = XlaExpression::Constant(tensor); if ((has_contents_proto || has_dynamic) && CanAttachContentsFromTensorShapeProto(tensor.shape(), contents_proto)) { - VLOG(1) << "ConstOp attaching shape contents to folded const with " - << tensor.NumElements() << " entries and dynamic_exprs=" - << CountDynamicShapeContents(contents_proto); output.set_contents( BuildShapeContentsFromTensorShapeProto(contents_proto)); } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index c2c2eed6261af2..9cee31219b3c81 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -54,6 +54,8 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; const char kUserInferredValueContentsAttrName[] = "user_inferred_value_contents"; +const char kUserInferredValueContentsSerializedAttrName[] = + "user_inferred_value_contents_serialized"; bool IsShapeOp(const Node* n); @@ -190,6 +192,14 @@ bool TryGetFoldedValueContents(const Node* node, int output_index, return false; } + string serialized_contents; + if (GetNodeAttr(node->attrs(), kUserInferredValueContentsSerializedAttrName, + &serialized_contents) + .ok() && + out_contents->ParseFromString(serialized_contents)) { + return true; + } + TensorShapeProto existing_contents; if (GetNodeAttr(node->attrs(), kUserInferredValueContentsAttrName, &existing_contents) @@ -863,7 +873,8 @@ void AddShapeNodeToConstantGraph( .Attr("user_inferred_shape", user_inferred_shape); } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { - builder.Attr(kUserInferredValueContentsAttrName, exact_contents); + builder.Attr(kUserInferredValueContentsSerializedAttrName, + exact_contents.SerializeAsString()); } NodeDef def; CHECK(builder.Finalize(&def).ok()); @@ -1020,7 +1031,8 @@ bool ReplaceTensorWithConstant( } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { builder.Attr("has_dynamic", true) - .Attr(kUserInferredValueContentsAttrName, exact_contents); + .Attr(kUserInferredValueContentsSerializedAttrName, + exact_contents.SerializeAsString()); } if (partition_device) { builder.Device(partition_device->name()); From d4e5ccb6d823619663567c427b2a0b2f1aef361d Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 13:28:55 +0100 Subject: [PATCH 7/8] Address symbolic constant folding review feedback --- tensorflow/compiler/jit/kernels/xla_ops.cc | 26 +++++- .../compiler/tf2xla/kernels/const_op.cc | 29 +++++- .../core/common_runtime/constant_folding.cc | 93 ++++++++++++++++--- .../common_runtime/constant_folding_test.cc | 8 +- 4 files changed, 132 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 3c3e5842b153c1..48e5aeb2a1dcb0 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -101,6 +101,14 @@ limitations under the License. namespace tensorflow { namespace { +constexpr char kUserInferredValueContentsAttrName[] = + "_user_inferred_value_contents"; +constexpr char kUserInferredValueContentsSerializedAttrName[] = + "_user_inferred_value_contents_serialized"; +constexpr char kLegacyUserInferredValueContentsAttrName[] = + "user_inferred_value_contents"; +constexpr char kLegacyUserInferredValueContentsSerializedAttrName[] = + "user_inferred_value_contents_serialized"; using XlaDeviceCompiler = DeviceCompiler; using PjRtDeviceCompiler = @@ -522,9 +530,17 @@ absl::Status CompileToLocalExecutable( auto inferred_shape_it = attr_map.find("user_inferred_shape"); auto inferred_contents_it = - attr_map.find("user_inferred_value_contents"); + attr_map.find(kUserInferredValueContentsAttrName); + if (inferred_contents_it == attr_map.end()) { + inferred_contents_it = + attr_map.find(kLegacyUserInferredValueContentsAttrName); + } auto inferred_contents_serialized_it = - attr_map.find("user_inferred_value_contents_serialized"); + attr_map.find(kUserInferredValueContentsSerializedAttrName); + if (inferred_contents_serialized_it == attr_map.end()) { + inferred_contents_serialized_it = + attr_map.find(kLegacyUserInferredValueContentsSerializedAttrName); + } bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); if (has_dynamic_it != attr_map.end()) { @@ -539,8 +555,10 @@ absl::Status CompileToLocalExecutable( TensorShapeProto inferred_shape_proto; if (inferred_contents_serialized_it != attr_map.end()) { - inferred_shape_proto.ParseFromString( - inferred_contents_serialized_it->second.s()); + if (!inferred_shape_proto.ParseFromString( + inferred_contents_serialized_it->second.s())) { + return; + } } else if (inferred_contents_it != attr_map.end()) { inferred_shape_proto = inferred_contents_it->second.shape(); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 1c1c3874493b76..b480baffd02b3f 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -35,6 +35,15 @@ limitations under the License. namespace tensorflow { namespace { +constexpr char kUserInferredValueContentsAttrName[] = + "_user_inferred_value_contents"; +constexpr char kUserInferredValueContentsSerializedAttrName[] = + "_user_inferred_value_contents_serialized"; +constexpr char kLegacyUserInferredValueContentsAttrName[] = + "user_inferred_value_contents"; +constexpr char kLegacyUserInferredValueContentsSerializedAttrName[] = + "user_inferred_value_contents_serialized"; + template DstT CastTo(SrcT src) { return static_cast(src); @@ -235,15 +244,27 @@ class ConstOp : public XlaOpKernel { .ok()) { } if (GetNodeAttr(ctx->op_kernel().def(), - "user_inferred_value_contents_serialized", + kUserInferredValueContentsSerializedAttrName, + &inferred_value_contents_serialized) + .ok() || + GetNodeAttr(ctx->op_kernel().def(), + kLegacyUserInferredValueContentsSerializedAttrName, &inferred_value_contents_serialized) .ok()) { - inferred_value_contents_proto.ParseFromString( - inferred_value_contents_serialized); + if (!inferred_value_contents_proto.ParseFromString( + inferred_value_contents_serialized)) { + inferred_value_contents_proto.Clear(); + } } else { - GetNodeAttr(ctx->op_kernel().def(), "user_inferred_value_contents", + GetNodeAttr(ctx->op_kernel().def(), kUserInferredValueContentsAttrName, &inferred_value_contents_proto) .IgnoreError(); + if (inferred_value_contents_proto.dim_size() == 0) { + GetNodeAttr(ctx->op_kernel().def(), + kLegacyUserInferredValueContentsAttrName, + &inferred_value_contents_proto) + .IgnoreError(); + } } const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; const TensorShapeProto& contents_proto = diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 9cee31219b3c81..4bbcd04db68863 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -53,8 +53,13 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; -const char kUserInferredValueContentsAttrName[] = "user_inferred_value_contents"; +const char kUserInferredValueContentsAttrName[] = + "_user_inferred_value_contents"; const char kUserInferredValueContentsSerializedAttrName[] = + "_user_inferred_value_contents_serialized"; +const char kLegacyUserInferredValueContentsAttrName[] = + "user_inferred_value_contents"; +const char kLegacyUserInferredValueContentsSerializedAttrName[] = "user_inferred_value_contents_serialized"; bool IsShapeOp(const Node* n); @@ -83,6 +88,72 @@ bool GetShapeFromDirectDynamicSource(const Node* node, GetShapeFromArgNode(node, out_shape); } +bool TryParseSerializedContentsAttr(const AttrSlice& attrs, + TensorShapeProto* out_contents) { + string serialized_contents; + if (!GetNodeAttr(attrs, kUserInferredValueContentsSerializedAttrName, + &serialized_contents) + .ok() && + !GetNodeAttr(attrs, kLegacyUserInferredValueContentsSerializedAttrName, + &serialized_contents) + .ok()) { + return false; + } + out_contents->Clear(); + return out_contents->ParseFromString(serialized_contents); +} + +bool TryGetContentsProtoAttr(const AttrSlice& attrs, + TensorShapeProto* out_contents) { + if (TryParseSerializedContentsAttr(attrs, out_contents)) { + return true; + } + if (GetNodeAttr(attrs, kUserInferredValueContentsAttrName, out_contents).ok()) { + return true; + } + return GetNodeAttr(attrs, kLegacyUserInferredValueContentsAttrName, + out_contents) + .ok(); +} + +bool HasTransitiveDynamicShapeContents( + const Node* node, std::unordered_map* memo) { + auto it = memo->find(node); + if (it != memo->end()) { + return it->second; + } + + TensorShapeProto contents_proto; + if (TryGetContentsProtoAttr(node->attrs(), &contents_proto) && + HasDynamicDimExprs(contents_proto)) { + return (*memo)[node] = true; + } + + bool has_dynamic = false; + TensorShapeProto inferred_shape_proto; + if (GetNodeAttr(node->attrs(), "has_dynamic", &has_dynamic).ok() && + has_dynamic && + GetNodeAttr(node->attrs(), "user_inferred_shape", &inferred_shape_proto) + .ok() && + HasDynamicDimExprs(inferred_shape_proto)) { + return (*memo)[node] = true; + } + + if (GetShapeFromDirectDynamicSource(node, &inferred_shape_proto) || + node->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { + return (*memo)[node] = true; + } + + for (const Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) continue; + if (HasTransitiveDynamicShapeContents(edge->src(), memo)) { + return (*memo)[node] = true; + } + } + + return (*memo)[node] = false; +} + bool GetConstTensor(const Node* node, Tensor* tensor) { if (node == nullptr || !node->IsConstant()) { return false; @@ -192,18 +263,12 @@ bool TryGetFoldedValueContents(const Node* node, int output_index, return false; } - string serialized_contents; - if (GetNodeAttr(node->attrs(), kUserInferredValueContentsSerializedAttrName, - &serialized_contents) - .ok() && - out_contents->ParseFromString(serialized_contents)) { + if (TryParseSerializedContentsAttr(node->attrs(), out_contents)) { return true; } TensorShapeProto existing_contents; - if (GetNodeAttr(node->attrs(), kUserInferredValueContentsAttrName, - &existing_contents) - .ok()) { + if (TryGetContentsProtoAttr(node->attrs(), &existing_contents)) { *out_contents = existing_contents; return true; } @@ -643,11 +708,15 @@ bool IsConstantFoldable( std::unordered_map>* shape_replacement_map) { TensorShapeProto exact_contents; const bool has_exact_contents = TryGetFoldedValueContents(n, 0, &exact_contents); - const bool has_dynamic = - GetShapeFromDirectDynamicSource(n, &exact_contents); + TensorShapeProto dynamic_shape; + const bool has_dynamic = GetShapeFromDirectDynamicSource(n, &dynamic_shape); const bool is_shape_derived = n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr; - if ((has_dynamic || is_shape_derived) && (!has_exact_contents || n->num_outputs() > 1)) { + std::unordered_map dynamic_contents_memo; + const bool has_transitive_dynamic_contents = + HasTransitiveDynamicShapeContents(n, &dynamic_contents_memo); + if ((has_dynamic || is_shape_derived || has_transitive_dynamic_contents) && + (!has_exact_contents || n->num_outputs() > 1)) { return false; } if (n->IsConstant()) { diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 84b5f324f9a62c..7ccde4e6f0b4d4 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -680,7 +680,7 @@ TEST_F(ConstantFoldingTest, FoldShapeFromDynamicArgPreservesContents) { // Shape(arg), the replacement Const still carries symbolic contents. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "user_inferred_value_contents_serialized", + "_user_inferred_value_contents_serialized", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -729,7 +729,7 @@ TEST_F(ConstantFoldingTest, FoldSliceOfDynamicShapePreservesContents) { // content, not just the concrete value 977. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "user_inferred_value_contents_serialized", + "_user_inferred_value_contents_serialized", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -776,7 +776,7 @@ TEST_F(ConstantFoldingTest, FoldGatherOfDynamicShapePreservesContents) { // symbolic content, not just the concrete scalar value 977. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "user_inferred_value_contents_serialized", + "_user_inferred_value_contents_serialized", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -826,7 +826,7 @@ TEST_F(ConstantFoldingTest, FoldReshapeOfDynamicShapePreservesContents) { // selected symbolic content when the shape-vector result is scalarized. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "user_inferred_value_contents_serialized", + "_user_inferred_value_contents_serialized", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); From b10866744a101248f389f0705d86a4c95de93cac Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 27 Apr 2026 13:42:37 +0100 Subject: [PATCH 8/8] Simplify symbolic contents attr contract --- tensorflow/compiler/jit/kernels/xla_ops.cc | 23 ++--------------- .../compiler/tf2xla/kernels/const_op.cc | 23 +---------------- .../core/common_runtime/constant_folding.cc | 25 +++---------------- .../common_runtime/constant_folding_test.cc | 8 +++--- 4 files changed, 11 insertions(+), 68 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 48e5aeb2a1dcb0..e02f930079215b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -103,12 +103,6 @@ namespace tensorflow { namespace { constexpr char kUserInferredValueContentsAttrName[] = "_user_inferred_value_contents"; -constexpr char kUserInferredValueContentsSerializedAttrName[] = - "_user_inferred_value_contents_serialized"; -constexpr char kLegacyUserInferredValueContentsAttrName[] = - "user_inferred_value_contents"; -constexpr char kLegacyUserInferredValueContentsSerializedAttrName[] = - "user_inferred_value_contents_serialized"; using XlaDeviceCompiler = DeviceCompiler; using PjRtDeviceCompiler = @@ -531,16 +525,6 @@ absl::Status CompileToLocalExecutable( auto inferred_shape_it = attr_map.find("user_inferred_shape"); auto inferred_contents_it = attr_map.find(kUserInferredValueContentsAttrName); - if (inferred_contents_it == attr_map.end()) { - inferred_contents_it = - attr_map.find(kLegacyUserInferredValueContentsAttrName); - } - auto inferred_contents_serialized_it = - attr_map.find(kUserInferredValueContentsSerializedAttrName); - if (inferred_contents_serialized_it == attr_map.end()) { - inferred_contents_serialized_it = - attr_map.find(kLegacyUserInferredValueContentsSerializedAttrName); - } bool has_dynamic = false; auto has_dynamic_it = attr_map.find("has_dynamic"); if (has_dynamic_it != attr_map.end()) { @@ -548,19 +532,16 @@ absl::Status CompileToLocalExecutable( } if (inferred_contents_it == attr_map.end() && - inferred_contents_serialized_it == attr_map.end() && (!has_dynamic || inferred_shape_it == attr_map.end())) { return; } TensorShapeProto inferred_shape_proto; - if (inferred_contents_serialized_it != attr_map.end()) { + if (inferred_contents_it != attr_map.end()) { if (!inferred_shape_proto.ParseFromString( - inferred_contents_serialized_it->second.s())) { + inferred_contents_it->second.s())) { return; } - } else if (inferred_contents_it != attr_map.end()) { - inferred_shape_proto = inferred_contents_it->second.shape(); } else { inferred_shape_proto = inferred_shape_it->second.shape(); } diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index b480baffd02b3f..e3650249405aa4 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -37,12 +37,6 @@ namespace { constexpr char kUserInferredValueContentsAttrName[] = "_user_inferred_value_contents"; -constexpr char kUserInferredValueContentsSerializedAttrName[] = - "_user_inferred_value_contents_serialized"; -constexpr char kLegacyUserInferredValueContentsAttrName[] = - "user_inferred_value_contents"; -constexpr char kLegacyUserInferredValueContentsSerializedAttrName[] = - "user_inferred_value_contents_serialized"; template DstT CastTo(SrcT src) { @@ -243,28 +237,13 @@ class ConstOp : public XlaOpKernel { &inferred_shape_proto) .ok()) { } - if (GetNodeAttr(ctx->op_kernel().def(), - kUserInferredValueContentsSerializedAttrName, - &inferred_value_contents_serialized) - .ok() || - GetNodeAttr(ctx->op_kernel().def(), - kLegacyUserInferredValueContentsSerializedAttrName, + if (GetNodeAttr(ctx->op_kernel().def(), kUserInferredValueContentsAttrName, &inferred_value_contents_serialized) .ok()) { if (!inferred_value_contents_proto.ParseFromString( inferred_value_contents_serialized)) { inferred_value_contents_proto.Clear(); } - } else { - GetNodeAttr(ctx->op_kernel().def(), kUserInferredValueContentsAttrName, - &inferred_value_contents_proto) - .IgnoreError(); - if (inferred_value_contents_proto.dim_size() == 0) { - GetNodeAttr(ctx->op_kernel().def(), - kLegacyUserInferredValueContentsAttrName, - &inferred_value_contents_proto) - .IgnoreError(); - } } const bool has_contents_proto = inferred_value_contents_proto.dim_size() > 0; const TensorShapeProto& contents_proto = diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 4bbcd04db68863..ebb050cec35bb3 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -55,12 +55,6 @@ const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; const char kUserInferredValueContentsAttrName[] = "_user_inferred_value_contents"; -const char kUserInferredValueContentsSerializedAttrName[] = - "_user_inferred_value_contents_serialized"; -const char kLegacyUserInferredValueContentsAttrName[] = - "user_inferred_value_contents"; -const char kLegacyUserInferredValueContentsSerializedAttrName[] = - "user_inferred_value_contents_serialized"; bool IsShapeOp(const Node* n); @@ -91,10 +85,7 @@ bool GetShapeFromDirectDynamicSource(const Node* node, bool TryParseSerializedContentsAttr(const AttrSlice& attrs, TensorShapeProto* out_contents) { string serialized_contents; - if (!GetNodeAttr(attrs, kUserInferredValueContentsSerializedAttrName, - &serialized_contents) - .ok() && - !GetNodeAttr(attrs, kLegacyUserInferredValueContentsSerializedAttrName, + if (!GetNodeAttr(attrs, kUserInferredValueContentsAttrName, &serialized_contents) .ok()) { return false; @@ -105,15 +96,7 @@ bool TryParseSerializedContentsAttr(const AttrSlice& attrs, bool TryGetContentsProtoAttr(const AttrSlice& attrs, TensorShapeProto* out_contents) { - if (TryParseSerializedContentsAttr(attrs, out_contents)) { - return true; - } - if (GetNodeAttr(attrs, kUserInferredValueContentsAttrName, out_contents).ok()) { - return true; - } - return GetNodeAttr(attrs, kLegacyUserInferredValueContentsAttrName, - out_contents) - .ok(); + return TryParseSerializedContentsAttr(attrs, out_contents); } bool HasTransitiveDynamicShapeContents( @@ -942,7 +925,7 @@ void AddShapeNodeToConstantGraph( .Attr("user_inferred_shape", user_inferred_shape); } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { - builder.Attr(kUserInferredValueContentsSerializedAttrName, + builder.Attr(kUserInferredValueContentsAttrName, exact_contents.SerializeAsString()); } NodeDef def; @@ -1100,7 +1083,7 @@ bool ReplaceTensorWithConstant( } if (has_exact_contents && HasDynamicDimExprs(exact_contents)) { builder.Attr("has_dynamic", true) - .Attr(kUserInferredValueContentsSerializedAttrName, + .Attr(kUserInferredValueContentsAttrName, exact_contents.SerializeAsString()); } if (partition_device) { diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 7ccde4e6f0b4d4..18ae0e74a007b6 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -680,7 +680,7 @@ TEST_F(ConstantFoldingTest, FoldShapeFromDynamicArgPreservesContents) { // Shape(arg), the replacement Const still carries symbolic contents. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "_user_inferred_value_contents_serialized", + "_user_inferred_value_contents", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -729,7 +729,7 @@ TEST_F(ConstantFoldingTest, FoldSliceOfDynamicShapePreservesContents) { // content, not just the concrete value 977. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "_user_inferred_value_contents_serialized", + "_user_inferred_value_contents", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -776,7 +776,7 @@ TEST_F(ConstantFoldingTest, FoldGatherOfDynamicShapePreservesContents) { // symbolic content, not just the concrete scalar value 977. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "_user_inferred_value_contents_serialized", + "_user_inferred_value_contents", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto)); @@ -826,7 +826,7 @@ TEST_F(ConstantFoldingTest, FoldReshapeOfDynamicShapePreservesContents) { // selected symbolic content when the shape-vector result is scalarized. string serialized_contents_proto; TF_ASSERT_OK(GetNodeAttr(folded->attrs(), - "_user_inferred_value_contents_serialized", + "_user_inferred_value_contents", &serialized_contents_proto)); TensorShapeProto contents_proto; ASSERT_TRUE(contents_proto.ParseFromString(serialized_contents_proto));