From b09f90cbf5a3756240d7ebd8b62357ca91e0e647 Mon Sep 17 00:00:00 2001
From: Steven Varoumas
Date: Sat, 4 Apr 2026 15:09:50 +0100
Subject: [PATCH] Error when dynamic sizes hit MLIR tf2xla

---
 .../compiler/tf2xla/mlir_xla_op_kernel.cc | 74 +++++++++++++++++++
 1 file changed, 74 insertions(+)

diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
index b1a93508d92896..33e9ed240aeaa7 100644
--- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 
 #include <memory>
+#include <variant>
 
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
@@ -35,7 +36,9 @@ limitations under the License.
 #include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/resource_base.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/status.h"
@@ -68,6 +71,75 @@ class MLIRContextResource : public ResourceBase {
   mlir::MLIRContext mlir_ctx_;
 };
 
+// Returns true iff any dimension expression attached to `shape` is dynamic.
+// NOTE(review): TensorShape::get_expressions() is not an upstream TensorFlow
+// API -- confirm it exists in the tree this patch targets.
+bool HasDynamicExpressions(const TensorShape& shape) {
+  for (const auto& expr : shape.get_expressions()) {
+    if (expr && expr->is_dynamic()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Overload for XLA shapes: true iff any attached shape expression is dynamic.
+bool HasDynamicExpressions(const xla::Shape& shape) {
+  for (const auto& expr : shape.expressions()) {
+    if (expr && expr->is_dynamic()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Returns an Unimplemented error if any compilation argument shape, or any
+// shape-valued attribute (kShape or kList of shapes) of a node in `graph`,
+// carries a dynamic shape expression; MlirXlaOpKernel cannot lower these.
+absl::Status RejectDynamicShapeExpressionsInMlirXlaOpKernel(
+    llvm::ArrayRef<XlaCompiler::Argument> args, const Graph& graph) {
+  for (size_t i = 0; i < args.size(); ++i) {
+    const auto& shape = args[i].shape;
+    const bool has_dynamic_exprs =
+        std::holds_alternative<TensorShape>(shape)
+            ? HasDynamicExpressions(std::get<TensorShape>(shape))
+            : HasDynamicExpressions(std::get<xla::Shape>(shape));
+    if (has_dynamic_exprs) {
+      return errors::Unimplemented(
+          "MlirXlaOpKernel does not support dynamic shape expressions. "
+          "Argument ",
+          i, " carries a dynamic expression.");
+    }
+  }
+
+  for (Node* node : graph.nodes()) {
+    for (const auto& name_attr_pair : node->attrs()) {
+      const auto& attr_name = name_attr_pair.first;
+      const auto& attr_value = name_attr_pair.second;
+      auto maybe_reject_shape = [&](const TensorShapeProto& shape_proto) {
+        const TensorShape shape(shape_proto);
+        if (!HasDynamicExpressions(shape)) {
+          return absl::OkStatus();
+        }
+        return errors::Unimplemented(
+            "MlirXlaOpKernel does not support dynamic shape expressions. "
+            "Node '",
+            node->name(), "' attribute '", attr_name,
+            "' carries a dynamic expression.");
+      };
+
+      if (attr_value.value_case() == AttrValue::kShape) {
+        TF_RETURN_IF_ERROR(maybe_reject_shape(attr_value.shape()));
+      } else if (attr_value.value_case() == AttrValue::kList) {
+        for (const auto& shape_proto : attr_value.list().shape()) {
+          TF_RETURN_IF_ERROR(maybe_reject_shape(shape_proto));
+        }
+      }
+    }
+  }
+  return absl::OkStatus();
+}
+
 } // namespace
 
 absl::Status MlirXlaOpKernel::ContextToXlaArgs(
@@ -143,6 +215,8 @@ absl::Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
   // Create a graph that wraps the kernel.
   TF_ASSIGN_OR_RETURN(auto graph,
                       CreateSingleOpGraph(def(), xla_args, result_dtypes));
+  TF_RETURN_IF_ERROR(
+      RejectDynamicShapeExpressionsInMlirXlaOpKernel(xla_args, *graph));
 
   ResourceMgr* res_manager = ctx->op_kernel_context()->resource_manager();
   MLIRContextResource* ctx_res;