Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions python/pypto/ir/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ def make_roundtrip_instrument() -> _passes.CallbackInstrument:
with SSA ``iter_args`` after ``ConvertToSSA``) have no valid Python DSL syntax.
The instrument cannot roundtrip what it cannot print; it warns and skips.

- **Variable pointer mismatch**: Dynamic-shape ``Var`` nodes (e.g. ``M``
in ``pl.Tensor[[M, N], pl.FP32]``) appear in multiple places (params,
return type, body). The original IR shares a single ``Var`` pointer
across all occurrences, but the parser may create separate ``Var``
objects for each occurrence. ``structural_equal`` uses pointer-based
bijection and detects this as a mismatch. This is a parser limitation
— it should reuse the same ``Var`` object for same-named dynamic-shape
parameters across all scopes.

Returns:
A ``CallbackInstrument`` named ``"RoundtripInstrument"``.
"""
Expand Down Expand Up @@ -97,14 +88,6 @@ def _after_pass(pass_obj: _passes.Pass, program: _ir.Program) -> None:
_ir.assert_structural_equal(program, reparsed)
except Exception as exc:
error_msg = str(exc)
# Variable pointer mismatch: dynamic-shape Var nodes (e.g. M in
# Tensor[[M, N], FP32]) share a single pointer in the original IR,
# but the parser may create separate Var objects for each occurrence.
# The bijection in structural_equal detects this as a mismatch.
# TODO(#929): fix the parser to reuse same-named dynamic-shape Var
# objects across param types, return types, and body — then remove.
if "Variable pointer mismatch" in error_msg:
return
raise RuntimeError(
f"[RoundtripInstrument] Structural equality failed after pass '{pass_name}'.\n"
f"\n"
Expand Down
5 changes: 5 additions & 0 deletions src/ir/transforms/init_memref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ class InitMemRefMutator : public IRMutator {
auto memref = CreateMemRef(shaped_type, var, memory_space);
new_type = CloneTypeWithMemRefAndRemapExprs(
var_expr->GetType(), memref, [this](const ExprPtr& expr) { return VisitExpr(expr); }, memory_space);
} else {
// Non-shaped types (e.g. ScalarType for dynamic-shape dimensions like M, N)
// don't need MemRef initialization — return the original Var to preserve
// pointer identity across all type annotations that reference it.
return var;
}

return std::make_shared<Var>(var->name_hint_, new_type, var->span_);
Expand Down
Loading