From d53ae674223cbf6152304380de18e8738f173306 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 23:51:46 +0000 Subject: [PATCH 1/8] Require explicit expressions in shape builders Thread expressions through dense shape builders --- tensorflow/compiler/tf2xla/shape_util.cc | 30 +++++---- .../tpu/kernels/tpu_compile_op_support.cc | 7 +- ...riton_xla_extract_insert_to_triton_pass.cc | 4 +- .../xla/xla/hlo/parser/hlo_parser_test.cc | 5 +- .../collectives/collective_quantizer.cc | 1 + .../transforms/expanders/reduce_decomposer.cc | 1 + .../translate/mhlo_to_hlo/type_to_shape.cc | 5 +- third_party/xla/xla/literal.cc | 1 + third_party/xla/xla/literal_util.h | 4 ++ third_party/xla/xla/pjrt/utils.cc | 3 +- .../xla/xla/service/gpu/matmul_utils.cc | 7 +- .../service/gpu/transforms/gemm_rewriter.cc | 2 + .../gpu/transforms/horizontal_loop_fusion.cc | 2 + .../transforms/reduction_layout_normalizer.cc | 1 + .../gpu/transforms/windowed_einsum_handler.cc | 2 + .../xla/xla/service/layout_assignment.cc | 2 + .../xla/xla/service/llvm_ir/ir_array.h | 5 +- third_party/xla/xla/shape_test.cc | 18 ++++- third_party/xla/xla/shape_util.cc | 53 +++++++++------ third_party/xla/xla/shape_util.h | 66 ++++++++++++------- 20 files changed, 147 insertions(+), 72 deletions(-) diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index f2dccaea7b1cac..e516763482f58a 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -123,7 +123,8 @@ absl::Status TensorShapeToBoundedXLAShape( TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. - *shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); + *shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + type, {0}, {xla::DynExpr::_(0)}, {0}); return absl::OkStatus(); } @@ -151,8 +152,15 @@ absl::Status TensorShapeToBoundedXLAShape( } // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); + std::vector expressions(rank); + for (int d = 0; d < rank; ++d) { + expressions[d] = tensor_shape.dim_size(d) < 0 + ? xla::DynExpr::_(dimensions[d]) + : xla::DynExpr::_(tensor_shape.dim_size(d)); + } xla::Shape result = - xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, expressions, + layout); for (int d = 0; d < rank; ++d) { if (tensor_shape.dim_size(d) < 0) { result.set_dynamic_dimension(d, true); @@ -166,7 +174,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const PartialTensorShape& tensor_shape) { if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. - return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); + return xla::ShapeUtil::MakeShapeWithDenseLayout( + type, {0}, {xla::DynExpr::_(0)}, {0}); } int rank = tensor_shape.dims(); std::vector dimensions(rank); @@ -177,16 +186,15 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, if (dimensions[d] < 0) { LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA " "shape; returning unknown sentinel value"; - return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); + return xla::ShapeUtil::MakeShapeWithDenseLayout( + type, {0}, {xla::DynExpr::_(0)}, {0}); } expressions[d] = tensor_shape.get_expression(d); } // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::Shape result = - xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - result.set_expressions(expressions); - return result; + return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, + expressions, layout); } // Convert a TensorShape into the equivalent XLA Shape proto. @@ -225,10 +233,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - auto shape = - xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); - shape.set_expressions(expressions); - return shape; + return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, + expressions, layout); } absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc index 8fe6537b5168a8..a123d94d5edbac 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -214,11 +214,14 @@ Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding, for (int64_t i = 0; i < limit.size(); ++i) { dimensions[i] = limit[i] - offset[i]; } + auto expressions = xla::ShapeUtil::MakeConstantExpressions(dimensions); if (shape.has_layout()) { return xla::ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), dimensions, shape.layout().minor_to_major()); + shape.element_type(), dimensions, expressions, + shape.layout().minor_to_major()); } - return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions); + return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions, + expressions); } absl::Status AddVariableUpdatesToCores( diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index 4b6d01a662f0fc..c3fab46b2a3ec5 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -216,7 +216,9 @@ Value ComputeLinearOffset(::xla::EmitterLocOpBuilder& builder, ValueRange offsets, llvm::ArrayRef layout) { ::xla::Shape shape = ::xla::ShapeUtil::MakeShapeWithDenseLayout( xgt::GetPrimitiveType(tensor_type.getElementType()).value(), - tensor_type.getShape(), layout); + tensor_type.getShape(), + ::xla::ShapeUtil::MakeConstantExpressions(tensor_type.getShape()), + layout); ::xla::Shape linear_shape = ::xla::ShapeUtil::MakeShape( shape.element_type(), {::xla::ShapeUtil::ElementsIn(shape)}); diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc index f49e109a69f1e2..a5450e52785344 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -5886,7 +5886,7 @@ TEST_F(HloParserTest, ParseBufferMoreThanOneElement) { TEST_F(HloParserTest, ParseBufferScalar) { std::string shape_string = "b(s32[])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - Shape expected = ShapeUtil::MakeBufferShape(S32, {}); + Shape expected = ShapeUtil::MakeBufferShape(S32, {}, {}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) << "actual: " << ShapeUtil::HumanString(actual); @@ -5895,7 +5895,8 @@ TEST_F(HloParserTest, ParseBufferScalar) { TEST_F(HloParserTest, ParseBufferArray) { std::string shape_string = "b(f32[8,16]{1,0})"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - Shape expected = ShapeUtil::MakeBufferShape(F32, {8, 16}); + std::vector expressions = {DynExpr::_(8), DynExpr::_(16)}; + Shape expected = ShapeUtil::MakeBufferShape(F32, {8, 16}, expressions); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) << "actual: " << ShapeUtil::HumanString(actual); diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc index b3c2ffe79ec00c..fc02b7daeff488 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc @@ -121,6 +121,7 @@ HloInstruction* ApplyUnaries(HloInstruction* instr, instr = instr->AddInstruction(unary->CloneWithNewOperands( ShapeUtil::MakeShapeWithDenseLayout( instr->shape().element_type(), unary->shape().dimensions(), + unary->shape().get_expressions(), unary->shape().layout().minor_to_major()), {instr})); } diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc index 2fe502429287b4..7d1f27aa59ae9f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc @@ -47,6 +47,7 @@ class VariadicReductionLayoutEqualizer : public DfsHloRewriteVisitor { if (first_input_s.layout() != input_s.layout()) { Shape new_input_s = ShapeUtil::MakeShapeWithDenseLayout( input_s.element_type(), input_s.dimensions(), + input_s.get_expressions(), first_input_s.layout().minor_to_major()); auto copy = MakeCopyHlo(input, new_input_s); changed = true; diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 9369b89d77ec61..ead92527f31451 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -130,8 +130,9 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector dimensions(m.getShape().begin(), m.getShape().end()); + auto expressions = ::xla::ShapeUtil::MakeConstantExpressions(dimensions); return ::xla::ShapeUtil::MakeShapeWithDenseLayout( - primitive_type, dimensions, minor_to_major); + primitive_type, dimensions, expressions, minor_to_major); } else if (auto t = mlir::dyn_cast(type)) { // TODO(jpienaar): This is only handling the base case with primitive // element type. @@ -188,7 +189,7 @@ Shape TypeToShape(mlir::Type type) { auto final_ordering = mlir::applyPermutationMap( dimToLvl, llvm::ArrayRef(ordering)); auto sparse_shape = ::xla::ShapeUtil::MakeShapeWithSparseLayout( - primitive_type, shape, final_ordering); + primitive_type, shape, expressions, final_ordering); return sparse_shape; } diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index d91cfe064922c6..4dc62679461e50 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -1303,6 +1303,7 @@ Literal LiteralBase::Slice(absl::Span start_indices, } auto result_shape = ShapeUtil::MakeShapeWithDenseLayout( shape().element_type(), result_dimensions, + ShapeUtil::MakeConstantExpressions(result_dimensions), LayoutUtil::MinorToMajor(shape())); ShapeUtil::CopyDynamicDimensions(&result_shape, shape()); Literal result_literal(result_shape); diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index 4f8568110bca32..076725fad64865 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -351,6 +351,9 @@ template primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, + ShapeUtil::MakeConstantExpressions( + {static_cast(values.size()), + static_cast(values.begin()->size())}), layout.minor_to_major())); literal.PopulateR2(values); return literal; @@ -438,6 +441,7 @@ template const Array& values, const Layout& layout) { Literal literal(ShapeUtil::MakeShapeWithDenseLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), + ShapeUtil::MakeConstantExpressions(values.dimensions()), layout.minor_to_major())); literal.PopulateFromArray(values); return literal; diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 7c203dc6cc101a..87828f0f429b4d 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -877,8 +877,9 @@ absl::StatusOr MakeShapeWithTrivialByteStrides( byte_stride *= dimensions[d]; } } + auto expressions = ShapeUtil::MakeConstantExpressions(dimensions); return ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, - minor_to_major); + expressions, minor_to_major); } absl::Status TestBufferDonationClashes( diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 7a203388fd8893..33abdbd887ea47 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -104,10 +104,11 @@ absl::StatusOr GetBatchRowColumnShape( }); }; + std::vector dimensions = {dim_size(batch_dims), dim_size(row_dims), + dim_size(col_dims)}; return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), - {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)}, - minor_to_major); + shape.element_type(), dimensions, + ShapeUtil::MakeConstantExpressions(dimensions), minor_to_major); } // Returns the matrix layout for a logical shape (batch, rows, columns). diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 5054a440778105..fdd1a6dc0da03a 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -1316,6 +1316,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { x = instr->AddInstruction(op.first->CloneWithNewOperands( ShapeUtil::MakeShapeWithDenseLayout( x->shape().element_type(), op.first->shape().dimensions(), + op.first->shape().get_expressions(), op.first->shape().layout().minor_to_major()), operands)); } @@ -1378,6 +1379,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShapeWithDenseLayout( instr->shape().element_type(), new_output_shape.dimensions(), + new_output_shape.get_expressions(), instr->shape().layout().minor_to_major()), operands_list, kCublasLtMatmulF8CallTarget)); TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config)); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 6f3d1d91144224..ed725a2da74ca9 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -568,6 +568,8 @@ absl::Status HorizontalLoopFusionImpl::CreateFusedComputation( Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout( new_output->shape().element_type(), {ShapeUtil::ElementsIn(new_output->shape())}, + ShapeUtil::MakeConstantExpressions( + {ShapeUtil::ElementsIn(new_output->shape())}), /*minor_to_major=*/std::vector(1, 0)); TF_ASSIGN_OR_RETURN(instr_outputs[j], MakeReshapeHlo(new_shape, new_output)); diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc index 38b8878eea128e..12b00e78c30f49 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc @@ -123,6 +123,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { operand_shape.element_type(), new_operand_shape_data); Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout( reduce_shape.element_type(), new_reduce_shape_data, + ShapeUtil::MakeConstantExpressions(new_reduce_shape_data), new_reduce_shape_layout); if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) { diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index ce454624144803..5021ded8d2b5d1 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -183,11 +183,13 @@ absl::StatusOr ShiftDequantizationF8( for (HloInstruction* unary : unaries[k]) { Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout( operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().get_expressions(), unary->shape().layout().minor_to_major()); operands[k] = unary->AddInstruction(unary->CloneWithNewOperands( ShapeUtil::MakeShapeWithDenseLayout( operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().get_expressions(), unary->shape().layout().minor_to_major()), {operands[k]})); } diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index b5adc212b53a17..e5fd82c6f599cf 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -1401,6 +1401,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout( output_shape.element_type(), output_shape.dimensions(), + output_shape.get_expressions(), LayoutUtil::MinorToMajor(output_layout)); Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = @@ -1539,6 +1540,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout( operand->shape().element_type(), operand->shape().dimensions(), + operand->shape().get_expressions(), LayoutUtil::MinorToMajor(operand_layout)); Shape output_shape = user->shape(); *output_shape.mutable_layout() = diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h index f5b2b7c4fbd792..2f4d6be9794f19 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.h +++ b/third_party/xla/xla/service/llvm_ir/ir_array.h @@ -140,8 +140,9 @@ class IrArray { } Shape AsShapeWithType(PrimitiveType element_type) const { - return ShapeUtil::MakeShapeWithDenseLayout(element_type, dims_, - layout_.minor_to_major()); + return ShapeUtil::MakeShapeWithDenseLayout( + element_type, dims_, ShapeUtil::MakeConstantExpressions(dims_), + layout_.minor_to_major()); } // Given that "this" is the target index of a reshape from `input_shape` diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index fb51518b116eb2..14f38d3cb3f82c 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -30,6 +30,15 @@ limitations under the License. namespace xla { namespace { +std::vector MakeExprs(absl::Span dimensions) { + std::vector expressions; + expressions.reserve(dimensions.size()); + for (int64_t d : dimensions) { + expressions.push_back(DynExpr::_(d)); + } + return expressions; +} + class ShapeTest : public ::testing::Test { protected: const Shape opaque_ = ShapeUtil::MakeOpaqueShape(); @@ -40,7 +49,8 @@ class ShapeTest : public ::testing::Test { const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2}); const Shape matrix2_ = ShapeUtil::MakeShapeWithDenseLayout(S32, {3, 4}, {0, 1}); - const Shape matrix_buffer_ = ShapeUtil::MakeBufferShape(S32, {3, 4}); + const Shape matrix_buffer_ = + ShapeUtil::MakeBufferShape(S32, {3, 4}, MakeExprs({3, 4})); const Shape tuple_ = ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); const Shape nested_tuple_ = @@ -141,9 +151,11 @@ TEST_F(ShapeTest, EqualityTest) { // Equal with Buffer shapes. EXPECT_TRUE( - Shape::Equal().IgnoreBuffer()(ShapeUtil::MakeBufferShape(S32, {3, 4}), + Shape::Equal().IgnoreBuffer()(ShapeUtil::MakeBufferShape( + S32, {3, 4}, MakeExprs({3, 4})), ShapeUtil::MakeShape(S32, {3, 4}))); - EXPECT_FALSE(Shape::Equal()(ShapeUtil::MakeBufferShape(S32, {3, 4}), + EXPECT_FALSE(Shape::Equal()(ShapeUtil::MakeBufferShape( + S32, {3, 4}, MakeExprs({3, 4})), ShapeUtil::MakeShape(S32, {3, 4}))); } diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 2a72f66d943dfd..db5d67cd731ab9 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -123,8 +123,9 @@ void PrintBufferShape(Printer* printer, const Shape& shape) { // its Layout. absl::StatusOr MakeShapeWithLayoutInternal( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major, - absl::Span tiles, int64_t tail_padding_alignment_in_elements, + absl::Span expressions, + absl::Span minor_to_major, absl::Span tiles, + int64_t tail_padding_alignment_in_elements, PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, int64_t element_size_in_bits, int64_t memory_space, absl::Span split_configs, @@ -138,8 +139,8 @@ absl::StatusOr MakeShapeWithLayoutInternal( return InvalidArgument("Unsupported element type: %s", PrimitiveType_Name(element_type)); } - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeUtil::MakeValidatedShape(element_type, dimensions)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape( + element_type, dimensions, expressions)); if (element_size_in_bits == ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) { // Only set element_size_in_bits if it's different from the default value. @@ -267,7 +268,7 @@ static std::vector MakeDynamicDimensions( return dynamic_dimensions; } -static std::vector MakeExpressions( +/* static */ std::vector ShapeUtil::MakeConstantExpressions( absl::Span dimensions) { std::vector expressions; expressions.reserve(dimensions.size()); @@ -284,7 +285,7 @@ static std::vector MakeExpressions( } /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { - return MakeShape(element_type, {}); + return MakeShape(element_type, {}, {}); } /* static */ Shape ShapeUtil::MakeShape( @@ -297,8 +298,10 @@ static std::vector MakeExpressions( } /* static */ Shape ShapeUtil::MakeBufferShape( - PrimitiveType element_type, absl::Span dimensions) { - return Shape::MakeBufferShape(MakeShape(element_type, dimensions)); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions) { + return Shape::MakeBufferShape( + MakeShape(element_type, dimensions, expressions)); } /* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions( @@ -311,9 +314,8 @@ static std::vector MakeExpressions( /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, absl::Span expressions) { - return MakeValidatedShape( - element_type, dimensions, MakeDynamicDimensions(dimensions), - expressions.empty() ? MakeExpressions(dimensions) : expressions); + return MakeValidatedShape(element_type, dimensions, + MakeDynamicDimensions(dimensions), expressions); } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( @@ -373,11 +375,12 @@ static std::vector MakeExpressions( /* static */ Shape ShapeUtil::MakeShapeWithDenseLayout( PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions, absl::Span minor_to_major, absl::Span tiles, int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits, int64_t memory_space, absl::Span split_configs) { auto ret = MakeShapeWithLayoutInternal( - element_type, dimensions, minor_to_major, tiles, + element_type, dimensions, expressions, minor_to_major, tiles, tail_padding_alignment_in_elements, /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, element_size_in_bits, @@ -389,12 +392,13 @@ static std::vector MakeExpressions( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions, absl::Span minor_to_major, PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits, int64_t memory_space, std::optional physical_shape) { auto ret = MakeShapeWithLayoutInternal( - element_type, dimensions, minor_to_major, + element_type, dimensions, expressions, minor_to_major, /*tiles=*/{}, tail_padding_alignment_in_elements, index_primitive_type, pointer_primitive_type, element_size_in_bits, memory_space, /*split_configs=*/{}, std::move(physical_shape)); @@ -433,10 +437,8 @@ static std::vector MakeExpressions( absl::Span expressions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); - auto shape = MakeShapeWithDenseLayout(element_type, dimensions, layout); - std::vector exprs(expressions.begin(), expressions.end()); - shape.set_expressions(exprs); - return shape; + return MakeShapeWithDenseLayout(element_type, dimensions, expressions, + layout); } /* static */ Shape @@ -450,7 +452,16 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } dims[i] = shape.dimensions(dim); } - Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims); + std::vector expressions(shape.dimensions().size()); + for (int i = 0; i < shape.dimensions().size(); ++i) { + int dim = i; + if (shape.has_layout()) { + dim = LayoutUtil::Major(shape.layout(), dim); + } + expressions[i] = shape.expressions(dim); + } + Shape new_shape = + MakeShapeWithDescendingLayout(shape.element_type(), dims, expressions); // Since the physical layout is kept the same, the tiles and element size are // the same also. if (shape.has_layout()) { @@ -1826,7 +1837,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, } Shape output_shape_with_layout = MakeShapeWithDenseLayout( - output_shape.element_type(), output_shape.dimensions(), output_layout); + output_shape.element_type(), output_shape.dimensions(), + output_shape.get_expressions(), output_layout); CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)) << "reshape is not a bitcast for input_shape: " << ShapeUtil::HumanStringWithLayout(input_shape) @@ -2059,7 +2071,8 @@ struct ParallelState { } // Create the shape of the "work" which has same layout as the original shape. - Shape work_shape = ShapeUtil::MakeShape(shape.element_type(), work_dims); + Shape work_shape = ShapeUtil::MakeShape(shape.element_type(), work_dims, + shape.get_expressions()); *work_shape.mutable_layout() = shape.layout(); // We target one task (partition) per available thread. diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 7e9ed98719ff4f..f9f96fd4623ec3 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -407,11 +407,12 @@ class ShapeUtil { return shape.element_type() != PRIMITIVE_TYPE_INVALID; } - // Constructs a new shape with the given element type and sequence of - // dimensions. + // Constructs a new shape with the given element type, dimensions, and + // per-dimension expressions. `dimensions` and `expressions` must have the + // same size. static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions); // Make a scalar shape with given primitive type. static Shape MakeScalarShape(PrimitiveType element_type); @@ -426,47 +427,58 @@ class ShapeUtil { absl::Span dimensions, const std::vector& dynamic_dimensions, absl::Span expressions); - // Constructs a new buffer shape with the given element type, and sequence of - // dimensions. + // Constructs a new buffer shape with the given element type, dimensions, and + // per-dimension expressions. `dimensions` and `expressions` must have the + // same size. static Shape MakeBufferShape(PrimitiveType element_type, - absl::Span dimensions); - - // Constructs a new shape with the given element type and sequence of - // dimensions. Method checks if the element type is valid, the shape's - // size fits in std::numeric_limits::max(), and dynamic size is not - // marked static. + absl::Span dimensions, + absl::Span expressions); + + // Constructs a new shape with the given element type, dimensions, and + // per-dimension expressions. Method checks if the element type is valid, the + // shape's size fits in std::numeric_limits::max(), and dynamic size + // is not marked static. `dimensions` and `expressions` must have the same + // size. static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions); + // As above, but also accepts an explicit dynamic-dimension mask. All three + // inputs must have the same size. static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions, - absl::Span expressions = {}); + absl::Span expressions); - // Creates a Shape with element type corresponding to T and the given - // dimensions + // Creates a Shape with element type corresponding to T, the given + // dimensions, and the given per-dimension expressions. template - static Shape MakeShapeWithType(absl::Span dimensions) { + static Shape MakeShapeWithType(absl::Span dimensions, + absl::Span expressions) { return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), - dimensions); + dimensions, expressions); } - // Constructs a new dense array shape with the given minor_to_major order in - // its Layout. Returns a value shape such that shape.has_layout(). + // Constructs a new dense array shape with the given dimensions, + // per-dimension expressions, and minor_to_major order in its Layout. + // `dimensions` and `expressions` must have the same size. Returns a value + // shape such that shape.has_layout(). static Shape MakeShapeWithDenseLayout( PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions, absl::Span minor_to_major, absl::Span tiles = {}, int64_t tail_padding_alignment_in_elements = 1, int64_t element_size_in_bits = 0, int64_t memory_space = 0, absl::Span split_configs = {}); - // Constructs a new sparse array shape with the given minor_to_major order - // in its Layout. Returns a value shape such that - // shape.has_layout(). + // Constructs a new sparse array shape with the given dimensions, + // per-dimension expressions, and minor_to_major order in its Layout. + // `dimensions` and `expressions` must have the same size. Returns a value + // shape such that shape.has_layout(). static Shape MakeShapeWithSparseLayout( PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions, absl::Span minor_to_major, PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID, @@ -483,10 +495,11 @@ class ShapeUtil { // Returns the same shape except with all dimensions set to be static. static Shape MakeShapeWithStaticDimensions(const Shape& shape); - // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). + // Constructs a new shape with the given dimensions, per-dimension + // expressions, and major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span expressions = {}); + absl::Span expressions); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions @@ -525,6 +538,11 @@ class ShapeUtil { // Returns whether the element type of the shape is complex. static bool ElementIsComplex(const Shape& shape); + // Builds one constant expression per dimension, mirroring the supplied + // dimension sizes. + static std::vector MakeConstantExpressions( + absl::Span dimensions); + // Returns whether the element type has the given bit width. static bool ElementHasBitWidth(const Shape& shape, int bits); From 5dc9c0a5c6775feb5bb633ad91d7fdb7ec55ad0a Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:09:02 +0000 Subject: [PATCH 2/8] Add default-expression dense shape overload --- third_party/xla/xla/shape_util.cc | 11 +++++++++++ third_party/xla/xla/shape_util.h | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index db5d67cd731ab9..532653652d2512 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -390,6 +390,17 @@ static std::vector MakeDynamicDimensions( return *ret; } +/* static */ Shape ShapeUtil::MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major, absl::Span tiles, + int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits, + int64_t memory_space, absl::Span split_configs) { + return MakeShapeWithDenseLayout( + element_type, dimensions, MakeConstantExpressions(dimensions), + minor_to_major, tiles, tail_padding_alignment_in_elements, + element_size_in_bits, memory_space, split_configs); +} + /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, absl::Span dimensions, absl::Span expressions, diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index f9f96fd4623ec3..ef5e57c0b0e985 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -472,6 +472,16 @@ class ShapeUtil { int64_t element_size_in_bits = 0, int64_t memory_space = 0, absl::Span split_configs = {}); + // Convenience overload for the common static case where expressions mirror + // the provided dimensions. + static Shape MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major, + absl::Span tiles = {}, + int64_t tail_padding_alignment_in_elements = 1, + int64_t element_size_in_bits = 0, int64_t memory_space = 0, + absl::Span split_configs = {}); + // Constructs a new sparse array shape with the given dimensions, // per-dimension expressions, and minor_to_major order in its Layout. // `dimensions` and `expressions` must have the same size. Returns a value From 3b7b27d8a701599c754608b702d4604bd25407fe Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:20:25 +0000 Subject: [PATCH 3/8] Add default-expression buffer shape overload --- .../triton_xla_extract_insert_to_triton_pass.cc | 4 +--- third_party/xla/xla/hlo/parser/hlo_parser_test.cc | 5 ++--- third_party/xla/xla/literal.cc | 4 +--- third_party/xla/xla/literal_util.h | 4 ---- third_party/xla/xla/pjrt/utils.cc | 3 +-- third_party/xla/xla/service/gpu/matmul_utils.cc | 3 +-- .../xla/service/gpu/transforms/horizontal_loop_fusion.cc | 2 -- .../service/gpu/transforms/reduction_layout_normalizer.cc | 1 - third_party/xla/xla/service/llvm_ir/ir_array.h | 3 +-- third_party/xla/xla/shape_test.cc | 8 +++----- third_party/xla/xla/shape_util.cc | 6 ++++++ third_party/xla/xla/shape_util.h | 4 ++++ 12 files changed, 20 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index c3fab46b2a3ec5..4b6d01a662f0fc 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -216,9 +216,7 @@ Value ComputeLinearOffset(::xla::EmitterLocOpBuilder& builder, ValueRange offsets, llvm::ArrayRef layout) { ::xla::Shape shape = ::xla::ShapeUtil::MakeShapeWithDenseLayout( xgt::GetPrimitiveType(tensor_type.getElementType()).value(), - tensor_type.getShape(), - ::xla::ShapeUtil::MakeConstantExpressions(tensor_type.getShape()), - layout); + tensor_type.getShape(), layout); ::xla::Shape linear_shape = ::xla::ShapeUtil::MakeShape( shape.element_type(), {::xla::ShapeUtil::ElementsIn(shape)}); diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc index a5450e52785344..f49e109a69f1e2 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -5886,7 +5886,7 @@ TEST_F(HloParserTest, ParseBufferMoreThanOneElement) { TEST_F(HloParserTest, ParseBufferScalar) { std::string shape_string = "b(s32[])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - Shape expected = ShapeUtil::MakeBufferShape(S32, {}, {}); + Shape expected = ShapeUtil::MakeBufferShape(S32, {}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) << "actual: " << ShapeUtil::HumanString(actual); @@ -5895,8 +5895,7 @@ TEST_F(HloParserTest, ParseBufferScalar) { TEST_F(HloParserTest, ParseBufferArray) { std::string shape_string = "b(f32[8,16]{1,0})"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - std::vector expressions = {DynExpr::_(8), DynExpr::_(16)}; - Shape expected = ShapeUtil::MakeBufferShape(F32, {8, 16}, expressions); + Shape expected = ShapeUtil::MakeBufferShape(F32, {8, 16}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) << "actual: " << ShapeUtil::HumanString(actual); diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 4dc62679461e50..b505e68562257f 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -1302,9 +1302,7 @@ Literal LiteralBase::Slice(absl::Span start_indices, result_dimensions.push_back(dimension); } auto result_shape = ShapeUtil::MakeShapeWithDenseLayout( - shape().element_type(), result_dimensions, - ShapeUtil::MakeConstantExpressions(result_dimensions), - LayoutUtil::MinorToMajor(shape())); + shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); ShapeUtil::CopyDynamicDimensions(&result_shape, shape()); Literal result_literal(result_shape); primitive_util::ArrayTypeSwitch( diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index 076725fad64865..4f8568110bca32 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -351,9 +351,6 @@ template primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, - ShapeUtil::MakeConstantExpressions( - {static_cast(values.size()), - static_cast(values.begin()->size())}), layout.minor_to_major())); literal.PopulateR2(values); return literal; @@ -441,7 +438,6 @@ template const Array& values, const Layout& layout) { Literal literal(ShapeUtil::MakeShapeWithDenseLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), - ShapeUtil::MakeConstantExpressions(values.dimensions()), layout.minor_to_major())); literal.PopulateFromArray(values); return literal; diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 87828f0f429b4d..7c203dc6cc101a 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -877,9 +877,8 @@ absl::StatusOr MakeShapeWithTrivialByteStrides( byte_stride *= dimensions[d]; } } - auto expressions = ShapeUtil::MakeConstantExpressions(dimensions); return ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, - expressions, minor_to_major); + minor_to_major); } absl::Status TestBufferDonationClashes( diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 33abdbd887ea47..929869ce2d9162 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -107,8 +107,7 @@ absl::StatusOr GetBatchRowColumnShape( std::vector dimensions = {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)}; return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), dimensions, - ShapeUtil::MakeConstantExpressions(dimensions), minor_to_major); + shape.element_type(), dimensions, minor_to_major); } // Returns the matrix layout for a logical shape (batch, rows, columns). diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index ed725a2da74ca9..6f3d1d91144224 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -568,8 +568,6 @@ absl::Status HorizontalLoopFusionImpl::CreateFusedComputation( Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout( new_output->shape().element_type(), {ShapeUtil::ElementsIn(new_output->shape())}, - ShapeUtil::MakeConstantExpressions( - {ShapeUtil::ElementsIn(new_output->shape())}), /*minor_to_major=*/std::vector(1, 0)); TF_ASSIGN_OR_RETURN(instr_outputs[j], MakeReshapeHlo(new_shape, new_output)); diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc index 12b00e78c30f49..38b8878eea128e 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc @@ -123,7 +123,6 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { operand_shape.element_type(), new_operand_shape_data); Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout( reduce_shape.element_type(), new_reduce_shape_data, - ShapeUtil::MakeConstantExpressions(new_reduce_shape_data), new_reduce_shape_layout); if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) { diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h index 2f4d6be9794f19..1e75f2a0e938fd 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.h +++ b/third_party/xla/xla/service/llvm_ir/ir_array.h @@ -141,8 +141,7 @@ class IrArray { Shape AsShapeWithType(PrimitiveType element_type) const { return ShapeUtil::MakeShapeWithDenseLayout( - element_type, dims_, ShapeUtil::MakeConstantExpressions(dims_), - layout_.minor_to_major()); + element_type, dims_, layout_.minor_to_major()); } // Given that "this" is the target index of a reshape from `input_shape` diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index 14f38d3cb3f82c..fd92cb7a9c2735 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -50,7 +50,7 @@ class ShapeTest : public ::testing::Test { const Shape matrix2_ = ShapeUtil::MakeShapeWithDenseLayout(S32, {3, 4}, {0, 1}); const Shape matrix_buffer_ = - ShapeUtil::MakeBufferShape(S32, {3, 4}, MakeExprs({3, 4})); + ShapeUtil::MakeBufferShape(S32, {3, 4}); const Shape tuple_ = ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); const Shape nested_tuple_ = @@ -151,11 +151,9 @@ TEST_F(ShapeTest, EqualityTest) { // Equal with Buffer shapes. EXPECT_TRUE( - Shape::Equal().IgnoreBuffer()(ShapeUtil::MakeBufferShape( - S32, {3, 4}, MakeExprs({3, 4})), + Shape::Equal().IgnoreBuffer()(ShapeUtil::MakeBufferShape(S32, {3, 4}), ShapeUtil::MakeShape(S32, {3, 4}))); - EXPECT_FALSE(Shape::Equal()(ShapeUtil::MakeBufferShape( - S32, {3, 4}, MakeExprs({3, 4})), + EXPECT_FALSE(Shape::Equal()(ShapeUtil::MakeBufferShape(S32, {3, 4}), ShapeUtil::MakeShape(S32, {3, 4}))); } diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 532653652d2512..ec79321bbeb878 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -304,6 +304,12 @@ static std::vector MakeDynamicDimensions( MakeShape(element_type, dimensions, expressions)); } +/* static */ Shape ShapeUtil::MakeBufferShape( + PrimitiveType element_type, absl::Span dimensions) { + return MakeBufferShape(element_type, dimensions, + MakeConstantExpressions(dimensions)); +} + /* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions( const Shape& shape) { Shape output = shape; diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index ef5e57c0b0e985..e9416fc9911041 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -433,6 +433,10 @@ class ShapeUtil { static Shape MakeBufferShape(PrimitiveType element_type, absl::Span dimensions, absl::Span expressions); + // Convenience overload for fully static buffer shapes. Expressions are + // derived from `dimensions`. + static Shape MakeBufferShape(PrimitiveType element_type, + absl::Span dimensions); // Constructs a new shape with the given element type, dimensions, and // per-dimension expressions. Method checks if the element type is valid, the From 06972939ce6656955968712e392700f31df7f01d Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:27:38 +0000 Subject: [PATCH 4/8] Preserve symbolic expressions in tf2xla shapes --- tensorflow/compiler/tf2xla/kernels/bincount_op.cc | 9 ++++++--- tensorflow/compiler/tf2xla/kernels/resampler_ops.cc | 8 +++++--- .../compiler/tf2xla/kernels/tensor_list_utils.cc | 12 ++++++++++-- tensorflow/compiler/tf2xla/kernels/where_op.cc | 8 +++++--- third_party/xla/xla/shape_util.cc | 5 +++++ third_party/xla/xla/shape_util.h | 4 ++++ 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index 4f31c79f91a719..3eef31467c319b 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -110,11 +110,14 @@ class DenseBincountOp : public XlaOpKernel { scatter_dnums.add_scatter_dims_to_operand_dims(0); if (rank == 2) { - output_shape = xla::ShapeUtil::MakeShape(dtype, {size, output_size}); + output_shape = xla::ShapeUtil::MakeShape( + dtype, {size, output_size}, + {input_shape.expressions(0), xla::DynExpr::_(output_size)}); scatter_dnums.add_inserted_window_dims(1); scatter_dnums.add_scatter_dims_to_operand_dims(1); - auto i_shape = - xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()}); + auto i_shape = xla::ShapeUtil::MakeShape(input_xla_type, + input_shape.dimensions(), + input_shape.expressions()); auto i = xla::Iota(ctx->builder(), i_shape, 0); i = xla::Reshape( i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index c54c4613d29e44..6cba8717d93a47 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -122,12 +122,14 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, for (auto dim : warp_shape) { dimensions.push_back(dim.size); } + auto expressions = warp_shape.get_expressions(); // Except the last dimension, which is of size 1. dimensions.back() = 1; + expressions.back() = xla::DynExpr::one; - auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), - /*iota_dimension=*/0); + auto batch_indices = xla::Iota( + b, xla::ShapeUtil::MakeShape(xla::S32, dimensions, expressions), + /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 1771181e440f31..9f20fde8a57373 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -241,9 +241,13 @@ absl::Status GetTensorListShapeFromElementTensorListShape( const xla::Shape& shape = xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i); std::vector dimensions = xla::SpanToVector(shape.dimensions()); + std::vector expressions = + xla::SpanToVector(shape.expressions()); dimensions.insert(dimensions.begin(), leading_dim); + expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim)); shapes.push_back( - xla::ShapeUtil::MakeShape(shape.element_type(), dimensions)); + xla::ShapeUtil::MakeShape(shape.element_type(), dimensions, + expressions)); if (leading_dim_is_dynamic) { shapes.back().set_dynamic_dimension(0, true); } @@ -267,9 +271,13 @@ absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, std::vector shapes; std::vector dimensions = xla::SpanToVector(element_shape.dimensions()); + std::vector expressions = + xla::SpanToVector(element_shape.expressions()); dimensions.insert(dimensions.begin(), leading_dim); + expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim)); shapes.push_back( - xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions)); + xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions, + expressions)); shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic); shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index 29bcf4fa0769ad..733292744029c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -161,8 +161,8 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { XlaOp condition = ctx->Input(0); TF_ASSIGN_OR_RETURN(xla::Shape input_shape, ctx->builder()->GetShape(condition)); - auto iota_shape = - xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions()); + auto iota_shape = xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions(), + input_shape.expressions()); int64_t flattened_size = xla::Product(iota_shape.dimensions()); xla::DynExpr* flattened_expr = xla::DynExpr::one; @@ -290,7 +290,9 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { // // and then scatter iotas[out_idxs] into the output. std::vector iotas_to_concat; - auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions()); + auto iota_shape = + xla::ShapeUtil::MakeShape(S32, input_shape.dimensions(), + input_shape.expressions()); iotas_to_concat.reserve(iota_shape.dimensions_size()); for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { iotas_to_concat.push_back( diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index ec79321bbeb878..db0a73dcc1c152 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -284,6 +284,11 @@ static std::vector MakeDynamicDimensions( return MakeValidatedShape(element_type, dimensions, expressions).value(); } +/* static */ Shape ShapeUtil::MakeShape( + PrimitiveType element_type, absl::Span dimensions) { + return MakeShape(element_type, dimensions, MakeConstantExpressions(dimensions)); +} + /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { return MakeShape(element_type, {}, {}); } diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index e9416fc9911041..2ba48309d9976b 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -413,6 +413,10 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, absl::Span expressions); + // Convenience overload for fully static shapes. Expressions are derived from + // `dimensions`. + static Shape MakeShape(PrimitiveType element_type, + absl::Span dimensions); // Make a scalar shape with given primitive type. static Shape MakeScalarShape(PrimitiveType element_type); From 2bdcc5f030f914e278e32cf4ac9a10baefea1826 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:31:23 +0000 Subject: [PATCH 5/8] Use default expressions for unknown tf2xla shapes --- tensorflow/compiler/tf2xla/shape_util.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index e516763482f58a..ca265bf39a36f7 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -123,8 +123,7 @@ absl::Status TensorShapeToBoundedXLAShape( TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. - *shape = xla::ShapeUtil::MakeShapeWithDenseLayout( - type, {0}, {xla::DynExpr::_(0)}, {0}); + *shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); return absl::OkStatus(); } @@ -174,8 +173,7 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const PartialTensorShape& tensor_shape) { if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. - return xla::ShapeUtil::MakeShapeWithDenseLayout( - type, {0}, {xla::DynExpr::_(0)}, {0}); + return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } int rank = tensor_shape.dims(); std::vector dimensions(rank); @@ -186,8 +184,7 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, if (dimensions[d] < 0) { LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA " "shape; returning unknown sentinel value"; - return xla::ShapeUtil::MakeShapeWithDenseLayout( - type, {0}, {xla::DynExpr::_(0)}, {0}); + return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } expressions[d] = tensor_shape.get_expression(d); } From ecce5ea11071f1f3cf0a3719e97336bc2e4832f4 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:34:26 +0000 Subject: [PATCH 6/8] Inline dense matmul shape dimensions again --- third_party/xla/xla/service/gpu/matmul_utils.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 929869ce2d9162..7a203388fd8893 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -104,10 +104,10 @@ absl::StatusOr GetBatchRowColumnShape( }); }; - std::vector dimensions = {dim_size(batch_dims), dim_size(row_dims), - dim_size(col_dims)}; return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), dimensions, minor_to_major); + shape.element_type(), + {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)}, + minor_to_major); } // Returns the matrix layout for a logical shape (batch, rows, columns). From e12c57741cac4646f2ddfdf56deef4049e8edcb4 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:42:46 +0000 Subject: [PATCH 7/8] Preserve more symbolic tf2xla shape expressions --- .../tf2xla/kernels/dynamic_partition_op.cc | 16 +++++++++++++--- .../compiler/tf2xla/kernels/dynamic_stitch_op.cc | 7 ++++++- .../compiler/tf2xla/kernels/resampler_ops.cc | 9 +++++++-- third_party/xla/xla/shape_util.cc | 6 ++++++ third_party/xla/xla/shape_util.h | 12 ++++++++++++ 5 files changed, 44 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index 84c091339c1804..e793dc96870fde 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -43,6 +43,14 @@ limitations under the License. namespace tensorflow { namespace { +xla::DynExpr* ProductOfExpressions(absl::Span expressions) { + xla::DynExpr* product = xla::DynExpr::one; + for (xla::DynExpr* expr : expressions) { + product = (*product * *expr).s(); + } + return product; +} + class DynamicPartitionOp : public XlaOpKernel { public: explicit DynamicPartitionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -162,11 +170,13 @@ class DynamicPartitionOp : public XlaOpKernel { int64_t input_count = xla::ShapeUtil::ElementsIn(data_shape); auto data_1d = xla::Reshape(data, {input_count}); auto partitions_1d = xla::Reshape(partitions, {input_count}); - xla::Shape data_1d_shape = - xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count}); + xla::Shape data_1d_shape = xla::ShapeUtil::MakeShape( + data_shape.element_type(), {input_count}, + {ProductOfExpressions(data_shape.expressions())}); xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape( - partition_shape.element_type(), {input_count}); + partition_shape.element_type(), {input_count}, + {ProductOfExpressions(partition_shape.expressions())}); std::vector output, partition_length; std::tie(output, partition_length) = DynamicPartition1D( diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 6674fc6cde0793..ff7cdd2dfa63d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -132,13 +132,18 @@ class DynamicStitchOp : public XlaOpKernel { int64_t result_rank = 1 + data0_shape.dims() - indices0_shape.dims(); if (number_of_indices == 0) { std::vector result_shape(result_rank); + std::vector result_expressions(result_rank, + xla::DynExpr::zero); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d); + result_expressions[d - indices0_shape.dims() + 1] = + data0_shape.get_expression(d); } xla::PrimitiveType element_type = ctx->input_xla_type(ctx->num_inputs() - 1); xla::Literal empty_literal = xla::Literal::CreateFromShape( - xla::ShapeUtil::MakeShape(element_type, result_shape)); + xla::ShapeUtil::MakeShape(element_type, result_shape, + result_expressions)); ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal)); return; } diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 6cba8717d93a47..7ed5644f0fe181 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -367,14 +367,19 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); + auto warp_expressions = warp_shape.get_expressions(); + std::vector warp_expressions_without_last_dim( + warp_expressions.begin(), warp_expressions.end() - 1); // With dimension [batch, dim_0, ...dim_n, 4] std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); + auto neighbor_broadcast_expressions = warp_expressions_without_last_dim; + neighbor_broadcast_expressions.push_back(xla::DynExpr::_(4)); // With dimension [batch, dim_0, ...dim_n, 4] - auto neighbor_broadcast_shape = - xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); + auto neighbor_broadcast_shape = xla::ShapeUtil::MakeShape( + data_type, neighbor_broadcast_dims, neighbor_broadcast_expressions); const int64_t last_warp_dim = warp_shape.dims() - 1; diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index db0a73dcc1c152..a0ff04879dabf8 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -322,6 +322,12 @@ static std::vector MakeDynamicDimensions( return output; } +/* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions) { + return MakeValidatedShape(element_type, dimensions, + MakeConstantExpressions(dimensions)); +} + /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, absl::Span expressions) { diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 2ba48309d9976b..9fe3c787a32c02 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -447,6 +447,10 @@ class ShapeUtil { // shape's size fits in std::numeric_limits::max(), and dynamic size // is not marked static. `dimensions` and `expressions` must have the same // size. + static absl::StatusOr MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions); + + // As above, but also accepts explicit per-dimension expressions. static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, absl::Span expressions); @@ -467,6 +471,14 @@ class ShapeUtil { dimensions, expressions); } + // Convenience overload for fully static shapes. Expressions are derived from + // `dimensions`. + template + static Shape MakeShapeWithType(absl::Span dimensions) { + return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + dimensions); + } + // Constructs a new dense array shape with the given dimensions, // per-dimension expressions, and minor_to_major order in its Layout. // `dimensions` and `expressions` must have the same size. Returns a value From 410cd141fd0e17bd02ada8e5393be12f7a3ed5fc Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sun, 22 Mar 2026 00:43:24 +0000 Subject: [PATCH 8/8] Document missing symbolic categorical shape --- tensorflow/compiler/tf2xla/kernels/categorical_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index e8c804791299a7..f9118e08cb53f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -79,6 +79,9 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType uniform_xla_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); + // We only have an upper bound and a dynamism bit for num_samples here. + // Once tf2xla can recover a symbolic expression for that scalar input, + // this shape should preserve it instead of defaulting to constants. uniform_shape = xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); class_dimension = 2;