9 changes: 6 additions & 3 deletions tensorflow/compiler/tf2xla/kernels/bincount_op.cc
@@ -110,11 +110,14 @@ class DenseBincountOp : public XlaOpKernel {
     scatter_dnums.add_scatter_dims_to_operand_dims(0);
 
     if (rank == 2) {
-      output_shape = xla::ShapeUtil::MakeShape(dtype, {size, output_size});
+      output_shape = xla::ShapeUtil::MakeShape(
+          dtype, {size, output_size},
+          {input_shape.expressions(0), xla::DynExpr::_(output_size)});
       scatter_dnums.add_inserted_window_dims(1);
       scatter_dnums.add_scatter_dims_to_operand_dims(1);
-      auto i_shape =
-          xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
+      auto i_shape = xla::ShapeUtil::MakeShape(input_xla_type,
+                                               input_shape.dimensions(),
+                                               input_shape.expressions());
       auto i = xla::Iota(ctx->builder(), i_shape, 0);
       i = xla::Reshape(
           i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1},
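Note for reviewers unfamiliar with the new overload: throughout this change, xla::ShapeUtil::MakeShape gains a third argument carrying one symbolic size expression per dimension. A minimal sketch of the call pattern, assuming the DynExpr API this PR introduces (DynExpr::_ wraps a constant; Shape::expressions() returns the per-dimension expressions) — this is an illustration, not code from the PR:

    // Sketch only: the three-argument MakeShape overload and DynExpr::_
    // are additions from this PR, inferred from the call sites above.
    xla::Shape MakeRowMajorPair(xla::PrimitiveType type, int64_t rows,
                                int64_t cols, xla::DynExpr* symbolic_rows) {
      // Dimension 0 keeps a symbolic expression (e.g. recovered from an
      // input shape); dimension 1 is a constant wrapped via DynExpr::_.
      return xla::ShapeUtil::MakeShape(type, {rows, cols},
                                       {symbolic_rows, xla::DynExpr::_(cols)});
    }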
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -79,6 +79,9 @@ class CategoricalOp : public XlaOpKernel {
       xla::PrimitiveType uniform_xla_type;
       OP_REQUIRES_OK(ctx,
                      DataTypeToPrimitiveType(input_type(0), &uniform_xla_type));
+      // We only have an upper bound and a dynamism bit for num_samples here.
+      // Once tf2xla can recover a symbolic expression for that scalar input,
+      // this shape should preserve it instead of defaulting to constants.
       uniform_shape =
           xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array);
       class_dimension = 2;
16 changes: 13 additions & 3 deletions tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
@@ -43,6 +43,14 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+xla::DynExpr* ProductOfExpressions(absl::Span<xla::DynExpr* const> expressions) {
+  xla::DynExpr* product = xla::DynExpr::one;
+  for (xla::DynExpr* expr : expressions) {
+    product = (*product * *expr).s();
+  }
+  return product;
+}
+
 class DynamicPartitionOp : public XlaOpKernel {
  public:
  explicit DynamicPartitionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -162,11 +170,13 @@ class DynamicPartitionOp : public XlaOpKernel {
     int64_t input_count = xla::ShapeUtil::ElementsIn(data_shape);
     auto data_1d = xla::Reshape(data, {input_count});
     auto partitions_1d = xla::Reshape(partitions, {input_count});
-    xla::Shape data_1d_shape =
-        xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count});
+    xla::Shape data_1d_shape = xla::ShapeUtil::MakeShape(
+        data_shape.element_type(), {input_count},
+        {ProductOfExpressions(data_shape.expressions())});
 
     xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape(
-        partition_shape.element_type(), {input_count});
+        partition_shape.element_type(), {input_count},
+        {ProductOfExpressions(partition_shape.expressions())});
 
     std::vector<xla::XlaOp> output, partition_length;
     std::tie(output, partition_length) = DynamicPartition1D(
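A note on the helper's arithmetic, since DynExpr is new in this PR: judging from the diff, DynExpr::one is the multiplicative identity, operator* builds a product node, and .s() returns the interned pointer. Worked through on a rank-2 input: for data_shape with expressions {B, F}, ProductOfExpressions computes ((one * B) * F), so data_1d_shape in the second hunk carries the symbolic element count B * F alongside the static input_count.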
7 changes: 6 additions & 1 deletion tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -132,13 +132,18 @@ class DynamicStitchOp : public XlaOpKernel {
     int64_t result_rank = 1 + data0_shape.dims() - indices0_shape.dims();
     if (number_of_indices == 0) {
       std::vector<int64_t> result_shape(result_rank);
+      std::vector<xla::DynExpr*> result_expressions(result_rank,
+                                                    xla::DynExpr::zero);
       for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
         result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d);
+        result_expressions[d - indices0_shape.dims() + 1] =
+            data0_shape.get_expression(d);
       }
       xla::PrimitiveType element_type =
           ctx->input_xla_type(ctx->num_inputs() - 1);
       xla::Literal empty_literal = xla::Literal::CreateFromShape(
-          xla::ShapeUtil::MakeShape(element_type, result_shape));
+          xla::ShapeUtil::MakeShape(element_type, result_shape,
+                                    result_expressions));
       ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal));
       return;
     }
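To make the index arithmetic above concrete (a worked example, not code from the PR): with indices0_shape.dims() == 1 and data0_shape == [N, 3, 5], result_rank is 1 + 3 - 1 = 3; the loop writes sizes 3 and 5 (and their expressions) into result slots 1 and 2, while slot 0 — the stitched dimension — keeps size 0 and DynExpr::zero, since this path only runs when there are no indices at all.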
17 changes: 12 additions & 5 deletions tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
@@ -122,12 +122,14 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
   for (auto dim : warp_shape) {
     dimensions.push_back(dim.size);
   }
+  auto expressions = warp_shape.get_expressions();
   // Except the last dimension, which is of size 1.
   dimensions.back() = 1;
+  expressions.back() = xla::DynExpr::one;
 
-  auto batch_indices =
-      xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions),
-                /*iota_dimension=*/0);
+  auto batch_indices = xla::Iota(
+      b, xla::ShapeUtil::MakeShape(xla::S32, dimensions, expressions),
+      /*iota_dimension=*/0);
 
   return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
 }
@@ -365,14 +367,19 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
   auto warp_dims = warp_shape.dim_sizes();
   std::vector<int64_t> warp_dims_without_last_dims(warp_dims.begin(),
                                                    warp_dims.end() - 1);
+  auto warp_expressions = warp_shape.get_expressions();
+  std::vector<xla::DynExpr*> warp_expressions_without_last_dim(
+      warp_expressions.begin(), warp_expressions.end() - 1);
 
   // With dimension [batch, dim_0, ...dim_n, 4]
   std::vector<int64_t> neighbor_broadcast_dims = warp_dims_without_last_dims;
   neighbor_broadcast_dims.push_back(4);
+  auto neighbor_broadcast_expressions = warp_expressions_without_last_dim;
+  neighbor_broadcast_expressions.push_back(xla::DynExpr::_(4));
 
   // With dimension [batch, dim_0, ...dim_n, 4]
-  auto neighbor_broadcast_shape =
-      xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
+  auto neighbor_broadcast_shape = xla::ShapeUtil::MakeShape(
+      data_type, neighbor_broadcast_dims, neighbor_broadcast_expressions);
 
   const int64_t last_warp_dim = warp_shape.dims() - 1;
 
12 changes: 10 additions & 2 deletions tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
@@ -241,9 +241,13 @@ absl::Status GetTensorListShapeFromElementTensorListShape(
     const xla::Shape& shape =
         xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
     std::vector<int64_t> dimensions = xla::SpanToVector(shape.dimensions());
+    std::vector<xla::DynExpr*> expressions =
+        xla::SpanToVector(shape.expressions());
     dimensions.insert(dimensions.begin(), leading_dim);
+    expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim));
     shapes.push_back(
-        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
+        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions,
+                                  expressions));
     if (leading_dim_is_dynamic) {
       shapes.back().set_dynamic_dimension(0, true);
     }
@@ -267,9 +271,13 @@ absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
   std::vector<xla::Shape> shapes;
   std::vector<int64_t> dimensions =
       xla::SpanToVector(element_shape.dimensions());
+  std::vector<xla::DynExpr*> expressions =
+      xla::SpanToVector(element_shape.expressions());
   dimensions.insert(dimensions.begin(), leading_dim);
+  expressions.insert(expressions.begin(), xla::DynExpr::_(leading_dim));
   shapes.push_back(
-      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
+      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions,
+                                expressions));
   shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
   shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                              std::vector<int64_t>{}));
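Both call sites above follow the same recipe: prepend the list's leading dimension (and a constant expression for it) to the element shape, then mark dimension 0 dynamic when the list length is unknown. A hedged sketch of the resulting construction — max_len, batch, features, and element_type are hypothetical placeholders, and the expressions overload is this PR's addition:

    // Sketch: a tensor-list buffer shape for elements of shape
    // [batch, features] holding at most max_len elements.
    std::vector<int64_t> dims = {max_len, batch, features};
    std::vector<xla::DynExpr*> exprs = {xla::DynExpr::_(max_len),
                                        xla::DynExpr::_(batch),
                                        xla::DynExpr::_(features)};
    xla::Shape list_shape =
        xla::ShapeUtil::MakeShape(element_type, dims, exprs);
    list_shape.set_dynamic_dimension(0, /*is_dynamic=*/true);  // length unknown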
8 changes: 5 additions & 3 deletions tensorflow/compiler/tf2xla/kernels/where_op.cc
@@ -161,8 +161,8 @@ absl::StatusOr<XlaOp> CompileWhereWithSort(XlaOpKernelContext* ctx) {
   XlaOp condition = ctx->Input(0);
   TF_ASSIGN_OR_RETURN(xla::Shape input_shape,
                       ctx->builder()->GetShape(condition));
-  auto iota_shape =
-      xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions());
+  auto iota_shape = xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions(),
+                                              input_shape.expressions());
 
   int64_t flattened_size = xla::Product(iota_shape.dimensions());
   xla::DynExpr* flattened_expr = xla::DynExpr::one;
@@ -290,7 +290,9 @@ absl::StatusOr<XlaOp> CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) {
   //
   // and then scatter iotas[out_idxs] into the output.
   std::vector<XlaOp> iotas_to_concat;
-  auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions());
+  auto iota_shape =
+      xla::ShapeUtil::MakeShape(S32, input_shape.dimensions(),
+                                input_shape.expressions());
   iotas_to_concat.reserve(iota_shape.dimensions_size());
   for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) {
     iotas_to_concat.push_back(
21 changes: 12 additions & 9 deletions tensorflow/compiler/tf2xla/shape_util.cc
@@ -151,8 +151,15 @@ absl::Status TensorShapeToBoundedXLAShape(
   }
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
+  std::vector<xla::DynExpr*> expressions(rank);
+  for (int d = 0; d < rank; ++d) {
+    expressions[d] = tensor_shape.dim_size(d) < 0
+                         ? xla::DynExpr::_(dimensions[d])
+                         : xla::DynExpr::_(tensor_shape.dim_size(d));
+  }
   xla::Shape result =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
+      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, expressions,
+                                               layout);
   for (int d = 0; d < rank; ++d) {
     if (tensor_shape.dim_size(d) < 0) {
       result.set_dynamic_dimension(d, true);
@@ -183,10 +190,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
   }
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
-  xla::Shape result =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
-  result.set_expressions(expressions);
-  return result;
+  return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions,
+                                                  expressions, layout);
 }
 
 // Convert a TensorShape into the equivalent XLA Shape proto.
@@ -225,10 +230,8 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
 
-  auto shape =
-      xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout);
-  shape.set_expressions(expressions);
-  return shape;
+  return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions,
+                                                  expressions, layout);
 }
 
 absl::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
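A concrete reading of the bounded-shape path above (a worked example under the assumption that dimensions[d] already holds the bound for dynamic dims, which the shown context suggests but does not prove): for a TensorFlow shape [-1, 128] bounded at, say, 1024, dimensions holds {1024, 128}, both expressions become constants (DynExpr::_(1024) and DynExpr::_(128)), and the trailing loop re-marks dimension 0 as dynamic — the bound lives in the dimension size while the dynamism bit records that it is only an upper bound.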
7 changes: 5 additions & 2 deletions tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
@@ -214,11 +214,14 @@ Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
   for (int64_t i = 0; i < limit.size(); ++i) {
     dimensions[i] = limit[i] - offset[i];
   }
+  auto expressions = xla::ShapeUtil::MakeConstantExpressions(dimensions);
   if (shape.has_layout()) {
     return xla::ShapeUtil::MakeShapeWithDenseLayout(
-        shape.element_type(), dimensions, shape.layout().minor_to_major());
+        shape.element_type(), dimensions, expressions,
+        shape.layout().minor_to_major());
   }
-  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
+  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions,
+                                   expressions);
 }
 
 absl::Status AddVariableUpdatesToCores(
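MakeConstantExpressions is the bridge used here and in the MLIR lowering below for call sites that have no symbolic information. Its definition is not part of this diff; presumably it maps each static dimension to a constant expression, roughly like this sketch (inferred from its call sites and from the MakeExprs test helper at the bottom of this PR; the real implementation lives in ShapeUtil):

    // Hypothetical shape of the helper, inferred from this diff.
    std::vector<xla::DynExpr*> MakeConstantExpressions(
        absl::Span<const int64_t> dimensions) {
      std::vector<xla::DynExpr*> expressions;
      expressions.reserve(dimensions.size());
      for (int64_t dim : dimensions) {
        expressions.push_back(xla::DynExpr::_(dim));  // one constant per dim
      }
      return expressions;
    }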
@@ -121,6 +121,7 @@ HloInstruction* ApplyUnaries(HloInstruction* instr,
     instr = instr->AddInstruction(unary->CloneWithNewOperands(
         ShapeUtil::MakeShapeWithDenseLayout(
             instr->shape().element_type(), unary->shape().dimensions(),
+            unary->shape().get_expressions(),
             unary->shape().layout().minor_to_major()),
         {instr}));
   }
@@ -47,6 +47,7 @@ class VariadicReductionLayoutEqualizer : public DfsHloRewriteVisitor {
       if (first_input_s.layout() != input_s.layout()) {
         Shape new_input_s = ShapeUtil::MakeShapeWithDenseLayout(
             input_s.element_type(), input_s.dimensions(),
+            input_s.get_expressions(),
             first_input_s.layout().minor_to_major());
         auto copy = MakeCopyHlo(input, new_input_s);
         changed = true;
@@ -130,8 +130,9 @@ Shape TypeToShape(mlir::Type type) {
 
     llvm::SmallVector<int64_t, 4> dimensions(m.getShape().begin(),
                                              m.getShape().end());
+    auto expressions = ::xla::ShapeUtil::MakeConstantExpressions(dimensions);
     return ::xla::ShapeUtil::MakeShapeWithDenseLayout(
-        primitive_type, dimensions, minor_to_major);
+        primitive_type, dimensions, expressions, minor_to_major);
   } else if (auto t = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
     // TODO(jpienaar): This is only handling the base case with primitive
     // element type.
@@ -188,7 +189,7 @@ Shape TypeToShape(mlir::Type type) {
     auto final_ordering = mlir::applyPermutationMap(
         dimToLvl, llvm::ArrayRef<int64_t>(ordering));
     auto sparse_shape = ::xla::ShapeUtil::MakeShapeWithSparseLayout(
-        primitive_type, shape, final_ordering);
+        primitive_type, shape, expressions, final_ordering);
     return sparse_shape;
   }
 
3 changes: 1 addition & 2 deletions third_party/xla/xla/literal.cc
@@ -1302,8 +1302,7 @@ Literal LiteralBase::Slice(absl::Span<const int64_t> start_indices,
     result_dimensions.push_back(dimension);
   }
   auto result_shape = ShapeUtil::MakeShapeWithDenseLayout(
-      shape().element_type(), result_dimensions,
-      LayoutUtil::MinorToMajor(shape()));
+      shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape()));
   ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
   Literal result_literal(result_shape);
   primitive_util::ArrayTypeSwitch(
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1316,6 +1316,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       x = instr->AddInstruction(op.first->CloneWithNewOperands(
           ShapeUtil::MakeShapeWithDenseLayout(
               x->shape().element_type(), op.first->shape().dimensions(),
+              op.first->shape().get_expressions(),
              op.first->shape().layout().minor_to_major()),
          operands));
     }
@@ -1378,6 +1379,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     instr->AddInstruction(HloInstruction::CreateCustomCall(
         ShapeUtil::MakeShapeWithDenseLayout(
             instr->shape().element_type(), new_output_shape.dimensions(),
+            new_output_shape.get_expressions(),
             instr->shape().layout().minor_to_major()),
         operands_list, kCublasLtMatmulF8CallTarget));
   TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
@@ -183,11 +183,13 @@ absl::StatusOr<HloInstruction*> ShiftDequantizationF8(
     for (HloInstruction* unary : unaries[k]) {
       Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
          operands[k]->shape().element_type(), unary->shape().dimensions(),
+         unary->shape().get_expressions(),
          unary->shape().layout().minor_to_major());
 
       operands[k] = unary->AddInstruction(unary->CloneWithNewOperands(
           ShapeUtil::MakeShapeWithDenseLayout(
               operands[k]->shape().element_type(), unary->shape().dimensions(),
+              unary->shape().get_expressions(),
               unary->shape().layout().minor_to_major()),
           {operands[k]}));
     }
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/layout_assignment.cc
@@ -1401,6 +1401,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
   const Shape& output_shape = instruction->shape();
   Shape output_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout(
       output_shape.element_type(), output_shape.dimensions(),
+      output_shape.get_expressions(),
       LayoutUtil::MinorToMajor(output_layout));
   Shape operand_shape = operand->shape();
   *operand_shape.mutable_layout() =
@@ -1539,6 +1540,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
   }
   Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithDenseLayout(
       operand->shape().element_type(), operand->shape().dimensions(),
+      operand->shape().get_expressions(),
       LayoutUtil::MinorToMajor(operand_layout));
   Shape output_shape = user->shape();
   *output_shape.mutable_layout() =
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/llvm_ir/ir_array.h
@@ -140,8 +140,8 @@ class IrArray {
   }
 
   Shape AsShapeWithType(PrimitiveType element_type) const {
-    return ShapeUtil::MakeShapeWithDenseLayout(element_type, dims_,
-                                               layout_.minor_to_major());
+    return ShapeUtil::MakeShapeWithDenseLayout(
+        element_type, dims_, layout_.minor_to_major());
   }
 
   // Given that "this" is the target index of a reshape from `input_shape`
12 changes: 11 additions & 1 deletion third_party/xla/xla/shape_test.cc
@@ -30,6 +30,15 @@ limitations under the License.
 namespace xla {
 namespace {
 
+std::vector<DynExpr*> MakeExprs(absl::Span<const int64_t> dimensions) {
+  std::vector<DynExpr*> expressions;
+  expressions.reserve(dimensions.size());
+  for (int64_t d : dimensions) {
+    expressions.push_back(DynExpr::_(d));
+  }
+  return expressions;
+}
+
 class ShapeTest : public ::testing::Test {
  protected:
   const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
@@ -40,7 +49,8 @@ class ShapeTest : public ::testing::Test {
   const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
   const Shape matrix2_ =
       ShapeUtil::MakeShapeWithDenseLayout(S32, {3, 4}, {0, 1});
-  const Shape matrix_buffer_ = ShapeUtil::MakeBufferShape(S32, {3, 4});
+  const Shape matrix_buffer_ =
+      ShapeUtil::MakeBufferShape(S32, {3, 4});
   const Shape tuple_ =
       ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
   const Shape nested_tuple_ =
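The MakeExprs helper keeps the updated test fixtures terse. A hedged example of how a fixture might combine it with the new overload (dynamic_matrix_ is a hypothetical name; the three-argument MakeShape is the overload this PR adds):

    // Sketch of a possible fixture member using the helper above.
    const Shape dynamic_matrix_ =
        ShapeUtil::MakeShape(F32, {3, 4}, MakeExprs({3, 4}));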