diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
index 4f31c79f91a719..3eef31467c319b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
@@ -110,11 +110,14 @@ class DenseBincountOp : public XlaOpKernel {
     scatter_dnums.add_scatter_dims_to_operand_dims(0);
 
     if (rank == 2) {
-      output_shape = xla::ShapeUtil::MakeShape(dtype, {size, output_size});
+      output_shape = xla::ShapeUtil::MakeShape(
+          dtype, {size, output_size},
+          {input_shape.expressions(0), xla::DynExpr::_(output_size)});
       scatter_dnums.add_inserted_window_dims(1);
       scatter_dnums.add_scatter_dims_to_operand_dims(1);
-      auto i_shape =
-          xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
+      auto i_shape = xla::ShapeUtil::MakeShape(input_xla_type,
+                                               input_shape.dimensions(),
+                                               input_shape.expressions());
       auto i = xla::Iota(ctx->builder(), i_shape, 0);
       i = xla::Reshape(
           i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1},
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index e8c804791299a7..f9118e08cb53f7 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -79,6 +79,9 @@ class CategoricalOp : public XlaOpKernel {
       xla::PrimitiveType uniform_xla_type;
       OP_REQUIRES_OK(ctx,
                      DataTypeToPrimitiveType(input_type(0), &uniform_xla_type));
+      // We only have an upper bound and a dynamism bit for num_samples here.
+      // Once tf2xla can recover a symbolic expression for that scalar input,
+      // this shape should preserve it instead of defaulting to constants.
       uniform_shape =
           xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array);
       class_dimension = 2;
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
index 84c091339c1804..e793dc96870fde 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
@@ -43,6 +43,14 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+xla::DynExpr* ProductOfExpressions(absl::Span<xla::DynExpr* const> expressions) {
+  xla::DynExpr* product = xla::DynExpr::one;
+  for (xla::DynExpr* expr : expressions) {
+    product = (*product * *expr).s();
+  }
+  return product;
+}
+
 class DynamicPartitionOp : public XlaOpKernel {
  public:
   explicit DynamicPartitionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -162,11 +170,13 @@ class DynamicPartitionOp : public XlaOpKernel {
       int64_t input_count = xla::ShapeUtil::ElementsIn(data_shape);
       auto data_1d = xla::Reshape(data, {input_count});
       auto partitions_1d = xla::Reshape(partitions, {input_count});
-      xla::Shape data_1d_shape =
-          xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count});
+      xla::Shape data_1d_shape = xla::ShapeUtil::MakeShape(
+          data_shape.element_type(), {input_count},
+          {ProductOfExpressions(data_shape.expressions())});
 
       xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape(
-          partition_shape.element_type(), {input_count});
+          partition_shape.element_type(), {input_count},
+          {ProductOfExpressions(partition_shape.expressions())});
 
       std::vector<xla::XlaOp> output, partition_length;
       std::tie(output, partition_length) = DynamicPartition1D(
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index 6674fc6cde0793..ff7cdd2dfa63d4 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -132,13 +132,18 @@ class DynamicStitchOp : public XlaOpKernel {
     int64_t result_rank = 1 + data0_shape.dims() - indices0_shape.dims();
 
    if (number_of_indices == 0) {
       std::vector<int64_t> result_shape(result_rank);
+      std::vector<xla::DynExpr*> result_expressions(result_rank,
+                                                    xla::DynExpr::zero);
       for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
         result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d);
+        result_expressions[d - indices0_shape.dims() + 1] =
+            data0_shape.get_expression(d);
       }
       xla::PrimitiveType element_type =
           ctx->input_xla_type(ctx->num_inputs() - 1);
       xla::Literal empty_literal = xla::Literal::CreateFromShape(
-          xla::ShapeUtil::MakeShape(element_type, result_shape));
+          xla::ShapeUtil::MakeShape(element_type, result_shape,
+                                    result_expressions));
       ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal));
       return;
     }
diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
index c54c4613d29e44..7ed5644f0fe181 100644
--- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
@@ -122,12 +122,14 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
   for (auto dim : warp_shape) {
     dimensions.push_back(dim.size);
   }
+  auto expressions = warp_shape.get_expressions();
   // Except the last dimension, which is of size 1.
   dimensions.back() = 1;
+  expressions.back() = xla::DynExpr::one;
 
-  auto batch_indices =
-      xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions),
-                /*iota_dimension=*/0);
+  auto batch_indices = xla::Iota(
+      b, xla::ShapeUtil::MakeShape(xla::S32, dimensions, expressions),
+      /*iota_dimension=*/0);
 
   return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
 }
 
@@ -365,14 +367,19 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
   auto warp_dims = warp_shape.dim_sizes();
   std::vector<int64_t> warp_dims_without_last_dims(warp_dims.begin(),
                                                    warp_dims.end() - 1);
+  auto warp_expressions = warp_shape.get_expressions();
+  std::vector<xla::DynExpr*> warp_expressions_without_last_dim(
+      warp_expressions.begin(), warp_expressions.end() - 1);
 
   // With dimension [batch, dim_0, ...dim_n, 4]
   std::vector<int64_t> neighbor_broadcast_dims = warp_dims_without_last_dims;
   neighbor_broadcast_dims.push_back(4);
+  auto neighbor_broadcast_expressions = warp_expressions_without_last_dim;
+  neighbor_broadcast_expressions.push_back(xla::DynExpr::_(4));
 
   // With dimension [batch, dim_0, ...dim_n, 4]
-  auto neighbor_broadcast_shape =
-      xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
+  auto neighbor_broadcast_shape = xla::ShapeUtil::MakeShape(
+      data_type, neighbor_broadcast_dims, neighbor_broadcast_expressions);
 
   const int64_t last_warp_dim = warp_shape.dims() - 1;
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
index 1771181e440f31..9f20fde8a57373 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
@@ -241,9 +241,13 @@ absl::Status GetTensorListShapeFromElementTensorListShape(
     const xla::Shape& shape =
         xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
     std::vector<int64_t> dimensions = xla::SpanToVector(shape.dimensions());
+    std::vector<xla::DynExpr*> expressions =
+        xla::SpanToVector(shape.expressions());
     dimensions.insert(dimensions.begin(), leading_dim);
+    expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim));
     shapes.push_back(
-        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
+        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions,
+                                  expressions));
     if (leading_dim_is_dynamic) {
       shapes.back().set_dynamic_dimension(0, true);
     }
@@ -267,9 +271,13 @@ absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
   std::vector<xla::Shape> shapes;
   std::vector<int64_t> dimensions =
       xla::SpanToVector(element_shape.dimensions());
+  std::vector<xla::DynExpr*> expressions =
+      xla::SpanToVector(element_shape.expressions());
   dimensions.insert(dimensions.begin(), leading_dim);
+  expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim));
   shapes.push_back(
-      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
+      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions,
+                                expressions));
   shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
   shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                              std::vector<int64_t>{}));
diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc
index 29bcf4fa0769ad..733292744029c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/where_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc
@@ -161,8 +161,8 @@ absl::StatusOr<XlaOp> CompileWhereWithSort(XlaOpKernelContext* ctx) {
   XlaOp condition = ctx->Input(0);
   TF_ASSIGN_OR_RETURN(xla::Shape input_shape,
                       ctx->builder()->GetShape(condition));
-  auto iota_shape =
-      xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions());
+  auto iota_shape = xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions(),
+                                              input_shape.expressions());
   int64_t flattened_size = xla::Product(iota_shape.dimensions());
 
   xla::DynExpr* flattened_expr = xla::DynExpr::one;
@@ -290,7 +290,9 @@ absl::StatusOr<XlaOp> CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) {
   //
   // and then scatter iotas[out_idxs] into the output.
   std::vector<XlaOp> iotas_to_concat;
-  auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions());
+  auto iota_shape =
+      xla::ShapeUtil::MakeShape(S32, input_shape.dimensions(),
+                                input_shape.expressions());
   iotas_to_concat.reserve(iota_shape.dimensions_size());
   for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) {
     iotas_to_concat.push_back(
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index f2dccaea7b1cac..ca265bf39a36f7 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -151,8 +151,15 @@ absl::Status TensorShapeToBoundedXLAShape(
   }
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
+  std::vector<xla::DynExpr*> expressions(rank);
+  for (int d = 0; d < rank; ++d) {
+    expressions[d] = tensor_shape.dim_size(d) < 0
+                         ? xla::DynExpr::_(dimensions[d])
+                         : xla::DynExpr::_(tensor_shape.dim_size(d));
+  }
   xla::Shape result =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
+      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, expressions,
+                                               layout);
   for (int d = 0; d < rank; ++d) {
     if (tensor_shape.dim_size(d) < 0) {
       result.set_dynamic_dimension(d, true);
@@ -183,10 +190,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
   }
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
-  xla::Shape result =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
-  result.set_expressions(expressions);
-  return result;
+  return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions,
+                                                  expressions, layout);
 }
 
 // Convert a TensorShape into the equivalent XLA Shape proto.
@@ -225,10 +230,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
-  auto shape =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
-  shape.set_expressions(expressions);
-  return shape;
+  return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions,
+                                                  expressions, layout);
 }
 
 absl::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
index 8fe6537b5168a8..a123d94d5edbac 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
@@ -214,11 +214,14 @@ Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
   for (int64_t i = 0; i < limit.size(); ++i) {
     dimensions[i] = limit[i] - offset[i];
   }
+  auto expressions = xla::ShapeUtil::MakeConstantExpressions(dimensions);
   if (shape.has_layout()) {
     return xla::ShapeUtil::MakeShapeWithDenseLayout(
-        shape.element_type(), dimensions, shape.layout().minor_to_major());
+        shape.element_type(), dimensions, expressions,
+        shape.layout().minor_to_major());
   }
-  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
+  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions,
+                                   expressions);
 }
 
 absl::Status AddVariableUpdatesToCores(
diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc
index b3c2ffe79ec00c..fc02b7daeff488 100644
--- a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc
+++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc
@@ -121,6 +121,7 @@ HloInstruction* ApplyUnaries(HloInstruction* instr,
     instr = instr->AddInstruction(unary->CloneWithNewOperands(
         ShapeUtil::MakeShapeWithDenseLayout(
             instr->shape().element_type(), unary->shape().dimensions(),
+            unary->shape().get_expressions(),
             unary->shape().layout().minor_to_major()),
         {instr}));
   }
diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc
index 2fe502429287b4..7d1f27aa59ae9f 100644
--- a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc
+++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc
@@ -47,6 +47,7 @@ class VariadicReductionLayoutEqualizer : public DfsHloRewriteVisitor {
       if (first_input_s.layout() != input_s.layout()) {
         Shape new_input_s = ShapeUtil::MakeShapeWithDenseLayout(
             input_s.element_type(), input_s.dimensions(),
+            input_s.get_expressions(),
            first_input_s.layout().minor_to_major());
        auto copy = MakeCopyHlo(input, new_input_s);
        changed = true;
diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc
index 9369b89d77ec61..ead92527f31451 100644
--- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc
+++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc
@@ -130,8 +130,9 @@ Shape TypeToShape(mlir::Type type) {
     llvm::SmallVector<int64_t> dimensions(m.getShape().begin(),
                                           m.getShape().end());
+    auto expressions = ::xla::ShapeUtil::MakeConstantExpressions(dimensions);
     return ::xla::ShapeUtil::MakeShapeWithDenseLayout(
-        primitive_type, dimensions, minor_to_major);
+        primitive_type, dimensions, expressions, minor_to_major);
   } else if (auto t = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
     // TODO(jpienaar): This is only handling the base case with primitive
     // element type.
@@ -188,7 +189,7 @@ Shape TypeToShape(mlir::Type type) {
     auto final_ordering = mlir::applyPermutationMap(
         dimToLvl, llvm::ArrayRef(ordering));
     auto sparse_shape = ::xla::ShapeUtil::MakeShapeWithSparseLayout(
-        primitive_type, shape, final_ordering);
+        primitive_type, shape, expressions, final_ordering);
     return sparse_shape;
   }
diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc
index d91cfe064922c6..b505e68562257f 100644
--- a/third_party/xla/xla/literal.cc
+++ b/third_party/xla/xla/literal.cc
@@ -1302,8 +1302,7 @@ Literal LiteralBase::Slice(absl::Span<const int64_t> start_indices,
     result_dimensions.push_back(dimension);
   }
   auto result_shape = ShapeUtil::MakeShapeWithDenseLayout(
-      shape().element_type(), result_dimensions,
-      LayoutUtil::MinorToMajor(shape()));
+      shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape()));
   ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
   Literal result_literal(result_shape);
   primitive_util::ArrayTypeSwitch(
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
index 5054a440778105..fdd1a6dc0da03a 100644
--- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1316,6 +1316,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       x = instr->AddInstruction(op.first->CloneWithNewOperands(
           ShapeUtil::MakeShapeWithDenseLayout(
               x->shape().element_type(), op.first->shape().dimensions(),
+              op.first->shape().get_expressions(),
              op.first->shape().layout().minor_to_major()),
          operands));
    }
@@ -1378,6 +1379,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         instr->AddInstruction(HloInstruction::CreateCustomCall(
             ShapeUtil::MakeShapeWithDenseLayout(
                 instr->shape().element_type(), new_output_shape.dimensions(),
+                new_output_shape.get_expressions(),
                instr->shape().layout().minor_to_major()),
            operands_list, kCublasLtMatmulF8CallTarget));
    TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
index ce454624144803..5021ded8d2b5d1 100644
--- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -183,11 +183,13 @@ absl::StatusOr<bool> ShiftDequantizationF8(
     for (HloInstruction* unary : unaries[k]) {
       Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
           operands[k]->shape().element_type(), unary->shape().dimensions(),
+          unary->shape().get_expressions(),
          unary->shape().layout().minor_to_major());
       operands[k] = unary->AddInstruction(unary->CloneWithNewOperands(
           ShapeUtil::MakeShapeWithDenseLayout(
               operands[k]->shape().element_type(), unary->shape().dimensions(),
+              unary->shape().get_expressions(),
              unary->shape().layout().minor_to_major()),
          {operands[k]}));
    }
diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc
index b5adc212b53a17..e5fd82c6f599cf 100644
--- a/third_party/xla/xla/service/layout_assignment.cc
+++ b/third_party/xla/xla/service/layout_assignment.cc
@@ -1401,6 +1401,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
   const Shape& output_shape = instruction->shape();
   Shape output_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout(
       output_shape.element_type(), output_shape.dimensions(),
+      output_shape.get_expressions(),
       LayoutUtil::MinorToMajor(output_layout));
   Shape operand_shape = operand->shape();
   *operand_shape.mutable_layout() =
@@ -1539,6 +1540,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
   }
   Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout(
       operand->shape().element_type(), operand->shape().dimensions(),
+      operand->shape().get_expressions(),
       LayoutUtil::MinorToMajor(operand_layout));
   Shape output_shape = user->shape();
   *output_shape.mutable_layout() =
diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h
index f5b2b7c4fbd792..1e75f2a0e938fd 100644
--- a/third_party/xla/xla/service/llvm_ir/ir_array.h
+++ b/third_party/xla/xla/service/llvm_ir/ir_array.h
@@ -140,8 +140,8 @@ class IrArray {
   }
 
   Shape AsShapeWithType(PrimitiveType element_type) const {
-    return ShapeUtil::MakeShapeWithDenseLayout(element_type, dims_,
-                                               layout_.minor_to_major());
+    return ShapeUtil::MakeShapeWithDenseLayout(
+        element_type, dims_, layout_.minor_to_major());
   }
 
   // Given that "this" is the target index of a reshape from `input_shape`
diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc
index fb51518b116eb2..fd92cb7a9c2735 100644
--- a/third_party/xla/xla/shape_test.cc
+++ b/third_party/xla/xla/shape_test.cc
@@ -30,6 +30,15 @@ limitations under the License.
 namespace xla {
 namespace {
 
+std::vector<DynExpr*> MakeExprs(absl::Span<const int64_t> dimensions) {
+  std::vector<DynExpr*> expressions;
+  expressions.reserve(dimensions.size());
+  for (int64_t d : dimensions) {
+    expressions.push_back(DynExpr::_(d));
+  }
+  return expressions;
+}
+
 class ShapeTest : public ::testing::Test {
  protected:
   const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
@@ -40,7 +49,8 @@ class ShapeTest : public ::testing::Test {
   const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
   const Shape matrix2_ =
       ShapeUtil::MakeShapeWithDenseLayout(S32, {3, 4}, {0, 1});
-  const Shape matrix_buffer_ = ShapeUtil::MakeBufferShape(S32, {3, 4});
+  const Shape matrix_buffer_ =
+      ShapeUtil::MakeBufferShape(S32, {3, 4});
   const Shape tuple_ =
       ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
   const Shape nested_tuple_ =
diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc
index 2a72f66d943dfd..a0ff04879dabf8 100644
--- a/third_party/xla/xla/shape_util.cc
+++ b/third_party/xla/xla/shape_util.cc
@@ -123,8 +123,9 @@ void PrintBufferShape(Printer* printer, const Shape& shape) {
 // its Layout.
 absl::StatusOr<Shape> MakeShapeWithLayoutInternal(
     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
-    absl::Span<const int64_t> minor_to_major,
-    absl::Span<const Tile> tiles, int64_t tail_padding_alignment_in_elements,
+    absl::Span<DynExpr* const> expressions,
+    absl::Span<const int64_t> minor_to_major, absl::Span<const Tile> tiles,
+    int64_t tail_padding_alignment_in_elements,
     PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type,
     int64_t element_size_in_bits, int64_t memory_space,
     absl::Span<const SplitConfig> split_configs,
@@ -138,8 +139,8 @@ absl::StatusOr<Shape> MakeShapeWithLayoutInternal(
     return InvalidArgument("Unsupported element type: %s",
                            PrimitiveType_Name(element_type));
   }
-  TF_ASSIGN_OR_RETURN(Shape shape,
-                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
+  TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(
+                                       element_type, dimensions, expressions));
   if (element_size_in_bits ==
       ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
     // Only set element_size_in_bits if it's different from the default value.
@@ -267,7 +268,7 @@ static std::vector<bool> MakeDynamicDimensions(
   return dynamic_dimensions;
 }
 
-static std::vector<DynExpr*> MakeExpressions(
+/* static */ std::vector<DynExpr*> ShapeUtil::MakeConstantExpressions(
     absl::Span<const int64_t> dimensions) {
   std::vector<DynExpr*> expressions;
   expressions.reserve(dimensions.size());
@@ -283,8 +284,13 @@ static std::vector<DynExpr*> MakeExpressions(
   return MakeValidatedShape(element_type, dimensions, expressions).value();
 }
 
+/* static */ Shape ShapeUtil::MakeShape(
+    PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
+  return MakeShape(element_type, dimensions, MakeConstantExpressions(dimensions));
+}
+
 /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
-  return MakeShape(element_type, {});
+  return MakeShape(element_type, {}, {});
 }
 
 /* static */ Shape ShapeUtil::MakeShape(
@@ -296,9 +302,17 @@ static std::vector<DynExpr*> MakeExpressions(
       .value();
 }
 
+/* static */ Shape ShapeUtil::MakeBufferShape(
+    PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+    absl::Span<DynExpr* const> expressions) {
+  return Shape::MakeBufferShape(
+      MakeShape(element_type, dimensions, expressions));
+}
+
 /* static */ Shape ShapeUtil::MakeBufferShape(
     PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
-  return Shape::MakeBufferShape(MakeShape(element_type, dimensions));
+  return MakeBufferShape(element_type, dimensions,
+                         MakeConstantExpressions(dimensions));
 }
 
 /* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
@@ -308,12 +322,17 @@ static std::vector<DynExpr*> MakeExpressions(
   return output;
 }
 
+/* static */ absl::StatusOr<Shape> ShapeUtil::MakeValidatedShape(
+    PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
+  return MakeValidatedShape(element_type, dimensions,
+                            MakeConstantExpressions(dimensions));
+}
+
 /* static */ absl::StatusOr<Shape> ShapeUtil::MakeValidatedShape(
     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
     absl::Span<DynExpr* const> expressions) {
-  return MakeValidatedShape(
-      element_type, dimensions, MakeDynamicDimensions(dimensions),
-      expressions.empty() ? MakeExpressions(dimensions) : expressions);
+  return MakeValidatedShape(element_type, dimensions,
+                            MakeDynamicDimensions(dimensions), expressions);
 }
 
 /* static */ absl::StatusOr<Shape> ShapeUtil::MakeValidatedShape(
@@ -373,11 +392,12 @@ static std::vector<DynExpr*> MakeExpressions(
 
 /* static */ Shape ShapeUtil::MakeShapeWithDenseLayout(
     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+    absl::Span<DynExpr* const> expressions,
     absl::Span<const int64_t> minor_to_major, absl::Span<const Tile> tiles,
     int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits,
     int64_t memory_space, absl::Span<const SplitConfig> split_configs) {
   auto ret = MakeShapeWithLayoutInternal(
-      element_type, dimensions, minor_to_major, tiles,
+      element_type, dimensions, expressions, minor_to_major, tiles,
       tail_padding_alignment_in_elements,
       /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID,
       /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, element_size_in_bits,
@@ -387,14 +407,26 @@ static std::vector<DynExpr*> MakeExpressions(
   return *ret;
 }
 
+/* static */ Shape ShapeUtil::MakeShapeWithDenseLayout(
+    PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+    absl::Span<const int64_t> minor_to_major, absl::Span<const Tile> tiles,
+    int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits,
+    int64_t memory_space, absl::Span<const SplitConfig> split_configs) {
+  return MakeShapeWithDenseLayout(
+      element_type, dimensions, MakeConstantExpressions(dimensions),
+      minor_to_major, tiles, tail_padding_alignment_in_elements,
+      element_size_in_bits, memory_space, split_configs);
+}
+
 /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+    absl::Span<DynExpr* const> expressions,
     absl::Span<const int64_t> minor_to_major,
     PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type,
     int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits,
     int64_t memory_space, std::optional<Shape> physical_shape) {
   auto ret = MakeShapeWithLayoutInternal(
-      element_type, dimensions, minor_to_major,
+      element_type, dimensions, expressions, minor_to_major,
       /*tiles=*/{}, tail_padding_alignment_in_elements, index_primitive_type,
       pointer_primitive_type, element_size_in_bits, memory_space,
       /*split_configs=*/{}, std::move(physical_shape));
@@ -433,10 +465,8 @@ static std::vector<DynExpr*> MakeExpressions(
     absl::Span<DynExpr* const> expressions) {
   std::vector<int64_t> layout(dimensions.size());
   std::iota(layout.rbegin(), layout.rend(), static_cast<int64_t>(0));
-  auto shape = MakeShapeWithDenseLayout(element_type, dimensions, layout);
-  std::vector<DynExpr*> exprs(expressions.begin(), expressions.end());
-  shape.set_expressions(exprs);
-  return shape;
+  return MakeShapeWithDenseLayout(element_type, dimensions, expressions,
+                                  layout);
 }
 
 /* static */ Shape
 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
@@ -450,7 +480,16 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
     }
     dims[i] = shape.dimensions(dim);
   }
-  Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
+  std::vector<DynExpr*> expressions(shape.dimensions().size());
+  for (int i = 0; i < shape.dimensions().size(); ++i) {
+    int dim = i;
+    if (shape.has_layout()) {
+      dim = LayoutUtil::Major(shape.layout(), dim);
+    }
+    expressions[i] = shape.expressions(dim);
+  }
+  Shape new_shape =
+      MakeShapeWithDescendingLayout(shape.element_type(), dims, expressions);
   // Since the physical layout is kept the same, the tiles and element size are
   // the same also.
   if (shape.has_layout()) {
@@ -1826,7 +1865,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape,
   }
 
   Shape output_shape_with_layout = MakeShapeWithDenseLayout(
-      output_shape.element_type(), output_shape.dimensions(), output_layout);
+      output_shape.element_type(), output_shape.dimensions(),
+      output_shape.get_expressions(), output_layout);
   CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
       << "reshape is not a bitcast for input_shape: "
       << ShapeUtil::HumanStringWithLayout(input_shape)
@@ -2059,7 +2099,8 @@ struct ParallelState {
   }
 
   // Create the shape of the "work" which has same layout as the original shape.
-  Shape work_shape = ShapeUtil::MakeShape(shape.element_type(), work_dims);
+  Shape work_shape = ShapeUtil::MakeShape(shape.element_type(), work_dims,
+                                          shape.get_expressions());
   *work_shape.mutable_layout() = shape.layout();
 
   // We target one task (partition) per available thread.
diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h
index 7e9ed98719ff4f..9fe3c787a32c02 100644
--- a/third_party/xla/xla/shape_util.h
+++ b/third_party/xla/xla/shape_util.h
@@ -407,11 +407,16 @@ class ShapeUtil {
     return shape.element_type() != PRIMITIVE_TYPE_INVALID;
   }
 
-  // Constructs a new shape with the given element type and sequence of
-  // dimensions.
+  // Constructs a new shape with the given element type, dimensions, and
+  // per-dimension expressions. `dimensions` and `expressions` must have the
+  // same size.
   static Shape MakeShape(PrimitiveType element_type,
                          absl::Span<const int64_t> dimensions,
-                         absl::Span<DynExpr* const> expressions = {});
+                         absl::Span<DynExpr* const> expressions);
+  // Convenience overload for fully static shapes. Expressions are derived from
+  // `dimensions`.
+  static Shape MakeShape(PrimitiveType element_type,
+                         absl::Span<const int64_t> dimensions);
 
   // Make a scalar shape with given primitive type.
   static Shape MakeScalarShape(PrimitiveType element_type);
@@ -426,47 +431,84 @@ class ShapeUtil {
       absl::Span<const int64_t> dimensions,
       const std::vector<bool>& dynamic_dimensions,
       absl::Span<DynExpr* const> expressions);
 
-  // Constructs a new buffer shape with the given element type, and sequence of
-  // dimensions.
+  // Constructs a new buffer shape with the given element type, dimensions, and
+  // per-dimension expressions. `dimensions` and `expressions` must have the
+  // same size.
+  static Shape MakeBufferShape(PrimitiveType element_type,
+                               absl::Span<const int64_t> dimensions,
+                               absl::Span<DynExpr* const> expressions);
+  // Convenience overload for fully static buffer shapes. Expressions are
+  // derived from `dimensions`.
   static Shape MakeBufferShape(PrimitiveType element_type,
                                absl::Span<const int64_t> dimensions);
 
-  // Constructs a new shape with the given element type and sequence of
-  // dimensions. Method checks if the element type is valid, the shape's
-  // size fits in std::numeric_limits<int64_t>::max(), and dynamic size is not
-  // marked static.
+  // Constructs a new shape with the given element type, dimensions, and
+  // per-dimension expressions. Method checks if the element type is valid, the
+  // shape's size fits in std::numeric_limits<int64_t>::max(), and dynamic size
+  // is not marked static. `dimensions` and `expressions` must have the same
+  // size.
+  static absl::StatusOr<Shape> MakeValidatedShape(
+      PrimitiveType element_type, absl::Span<const int64_t> dimensions);
+
+  // As above, but also accepts explicit per-dimension expressions.
   static absl::StatusOr<Shape> MakeValidatedShape(
       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
-      absl::Span<DynExpr* const> expressions = {});
+      absl::Span<DynExpr* const> expressions);
+  // As above, but also accepts an explicit dynamic-dimension mask. All three
+  // inputs must have the same size.
   static absl::StatusOr<Shape> MakeValidatedShape(
       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
       const std::vector<bool>& dynamic_dimensions,
-      absl::Span<DynExpr* const> expressions = {});
+      absl::Span<DynExpr* const> expressions);
+
+  // Creates a Shape with element type corresponding to T, the given
+  // dimensions, and the given per-dimension expressions.
+  template <typename T>
+  static Shape MakeShapeWithType(absl::Span<const int64_t> dimensions,
+                                 absl::Span<DynExpr* const> expressions) {
+    return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
+                                dimensions, expressions);
+  }
 
-  // Creates a Shape with element type corresponding to T and the given
-  // dimensions
+  // Convenience overload for fully static shapes. Expressions are derived from
+  // `dimensions`.
   template <typename T>
   static Shape MakeShapeWithType(absl::Span<const int64_t> dimensions) {
     return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
                                 dimensions);
   }
 
-  // Constructs a new dense array shape with the given minor_to_major order in
-  // its Layout. Returns a value shape such that shape.has_layout().
+  // Constructs a new dense array shape with the given dimensions,
+  // per-dimension expressions, and minor_to_major order in its Layout.
+  // `dimensions` and `expressions` must have the same size. Returns a value
+  // shape such that shape.has_layout().
   static Shape MakeShapeWithDenseLayout(
       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+      absl::Span<DynExpr* const> expressions,
       absl::Span<const int64_t> minor_to_major,
       absl::Span<const Tile> tiles = {},
       int64_t tail_padding_alignment_in_elements = 1,
       int64_t element_size_in_bits = 0, int64_t memory_space = 0,
       absl::Span<const SplitConfig> split_configs = {});
 
-  // Constructs a new sparse array shape with the given minor_to_major order
-  // in its Layout. Returns a value shape such that
-  // shape.has_layout().
+  // Convenience overload for the common static case where expressions mirror
+  // the provided dimensions.
+  static Shape MakeShapeWithDenseLayout(
+      PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+      absl::Span<const int64_t> minor_to_major,
+      absl::Span<const Tile> tiles = {},
+      int64_t tail_padding_alignment_in_elements = 1,
+      int64_t element_size_in_bits = 0, int64_t memory_space = 0,
+      absl::Span<const SplitConfig> split_configs = {});
+
+  // Constructs a new sparse array shape with the given dimensions,
+  // per-dimension expressions, and minor_to_major order in its Layout.
+  // `dimensions` and `expressions` must have the same size. Returns a value
+  // shape such that shape.has_layout().
   static Shape MakeShapeWithSparseLayout(
       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+      absl::Span<DynExpr* const> expressions,
       absl::Span<const int64_t> minor_to_major,
       PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID,
       PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID,
@@ -483,10 +525,11 @@ class ShapeUtil {
   // Returns the same shape except with all dimensions set to be static.
   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
 
-  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
+  // Constructs a new shape with the given dimensions, per-dimension
+  // expressions, and major-first layout (i.e. {n, n-1, ..., 0}).
   static Shape MakeShapeWithDescendingLayout(
       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
-      absl::Span<DynExpr* const> expressions = {});
+      absl::Span<DynExpr* const> expressions);
 
   // Returns a new Shape based on the given Shape with low-dimension-major
   // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
@@ -525,6 +568,11 @@ class ShapeUtil {
   // Returns whether the element type of the shape is complex.
   static bool ElementIsComplex(const Shape& shape);
 
+  // Builds one constant expression per dimension, mirroring the supplied
+  // dimension sizes.
+  static std::vector<DynExpr*> MakeConstantExpressions(
+      absl::Span<const int64_t> dimensions);
+
   // Returns whether the element type has the given bit width.
   static bool ElementHasBitWidth(const Shape& shape, int bits);
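
For reference, here is a minimal usage sketch of the expressions-aware overloads introduced above. It leans only on the API surface visible in this diff: the `DynExpr::_` factory, the `DynExpr::one` identity, the `(*a * *b).s()` product form used by `ProductOfExpressions` in dynamic_partition_op.cc, and the indexed `Shape::expressions()` accessor. `FlattenToVector` itself is a hypothetical helper written for illustration, not part of this change.

// Sketch only: flattens a possibly dynamic shape to rank 1 while carrying
// the symbolic size expression alongside the static bound, mirroring what
// DynamicPartitionOp does via ProductOfExpressions.
#include <cstdint>

#include "xla/shape.h"
#include "xla/shape_util.h"

namespace xla {

Shape FlattenToVector(const Shape& shape) {
  int64_t flat_size = 1;              // static upper bound on element count
  DynExpr* flat_expr = DynExpr::one;  // multiplicative identity
  for (int i = 0; i < shape.dimensions().size(); ++i) {
    flat_size *= shape.dimensions(i);
    // Same product form as ProductOfExpressions above.
    flat_expr = (*flat_expr * *shape.expressions(i)).s();
  }
  // The rank-1 result keeps the symbolic product next to the static bound,
  // so downstream passes can still reason about the dynamic size.
  return ShapeUtil::MakeShape(shape.element_type(), {flat_size}, {flat_expr});
}

}  // namespace xla

The same pattern should apply anywhere a kernel reshapes to rank 1 (compare the where_op.cc and dynamic_partition_op.cc hunks): compute the static product for `dimensions` and the `DynExpr` product for `expressions`, rather than letting the constant-deriving convenience overload discard the symbolic information.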