Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,27 @@ LogicalResult TStoreOp::verify() {
return failure();
}
}

// Keep TSTORE contract explicit: destination tensor partition shape must
// match source tile valid_shape on every statically-known dimension.
auto dstShape = dstPart.getShape();
if (dstShape.size() != srcValid.size()) {
emitOpError() << "expects dst rank (" << dstShape.size()
<< ") to match src valid_shape rank (" << srcValid.size()
<< ")";
return failure();
}
for (auto [idx, dims] : llvm::enumerate(llvm::zip(dstShape, srcValid))) {
auto [dstDim, srcValidDim] = dims;
if (dstDim == ShapedType::kDynamic || srcValidDim == ShapedType::kDynamic)
continue;
if (dstDim != srcValidDim) {
emitOpError() << "expects dst shape[" << idx
<< "] to match src valid_shape[" << idx << "] ("
<< srcValidDim << "), but got " << dstDim;
return failure();
}
}
return std::make_pair(srcTile, dstPart);
};

Expand Down
15 changes: 15 additions & 0 deletions test/basic/tstore_verify_shape_valid_mismatch_invalid.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: ptoas %s 2>&1 | FileCheck %s

module {
func.func @tstore_shape_valid_mismatch_invalid(
%part : !pto.partition_tensor_view<64x64xf32>) {
%tile = pto.alloc_tile
: !pto.tile_buf<loc=vec, dtype=f32, rows=64, cols=64, v_row=32, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tstore ins(%tile : !pto.tile_buf<loc=vec, dtype=f32, rows=64, cols=64, v_row=32, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
outs(%part : !pto.partition_tensor_view<64x64xf32>)
return
}
}

// CHECK: error: 'pto.tstore' op expects dst shape[0] to match src valid_shape[0]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The FileCheck directive is a bit too general. It would be better to check for the specific values in the error message to make the test more robust. This ensures that the verifier is not only firing but also reporting the correct expected and actual dimension sizes.

// CHECK: error: 'pto.tstore' op expects dst shape[0] to match src valid_shape[0] (32), but got 64

Loading