[Pass] InitMemRef fragments dynamic-shape Var pointers, breaking roundtrip structural equality #970

@huxinyuan1215

Category

Bug (Pass)

Component

Passes / Transforms — InitMemRef

Description

The InitMemRef pass fragments dynamic-shape Var pointers. Before the pass, a single Var("M") shared_ptr is shared by all param types, return types, and the body. After the pass, each occurrence holds a distinct Var pointer with the same name and type, so structural_equal can no longer build a one-to-one (bijective) mapping between Vars and the roundtrip comparison fails.
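
To illustrate why splitting one Var into several same-named copies defeats a bijection-based comparison, here is a minimal sketch. It is not the project's structural_equal: the Var struct, StructurallyEqual, SharedUses, and FragmentedUses are hypothetical stand-ins, and the checker only pairs up Var pointers position by position.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-in for an IR variable; not the project's real Var class.
struct Var {
  std::string name;
  explicit Var(std::string n) : name(std::move(n)) {}
};
using VarPtr = std::shared_ptr<Var>;

// Hypothetical checker: two sequences of Var uses are equal only if their
// pointers can be paired one-to-one. Names alone are not enough.
bool StructurallyEqual(const std::vector<VarPtr>& lhs,
                       const std::vector<VarPtr>& rhs) {
  if (lhs.size() != rhs.size()) return false;
  std::unordered_map<const Var*, const Var*> fwd;  // lhs Var -> rhs Var
  std::unordered_map<const Var*, const Var*> bwd;  // rhs Var -> lhs Var
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    const Var* a = lhs[i].get();
    const Var* b = rhs[i].get();
    // emplace keeps an existing entry, so a conflicting pairing shows up
    // as a stored value that differs from the current partner.
    if (fwd.emplace(a, b).first->second != b) return false;
    if (bwd.emplace(b, a).first->second != a) return false;
  }
  return true;
}

// Before the pass: one Var("M") shared by param type, return type, and body.
std::vector<VarPtr> SharedUses() {
  VarPtr m = std::make_shared<Var>("M");
  return {m, m, m};
}

// After the pass: every use holds its own Var("M") object.
std::vector<VarPtr> FragmentedUses() {
  return {std::make_shared<Var>("M"), std::make_shared<Var>("M"),
          std::make_shared<Var>("M")};
}

// Usage: sharing survives the comparison, fragmentation does not.
inline void Demo() {
  assert(StructurallyEqual(SharedUses(), SharedUses()));
  assert(!StructurallyEqual(SharedUses(), FragmentedUses()));
}
```

In this sketch the comparison fails in both directions: one Var on the left cannot map to three Vars on the right, and three Vars on the left cannot all map to the same Var on the right.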

To work around this, RoundtripInstrument suppresses Variable pointer mismatch errors via string matching in instruments.py.

Root Cause

In ProcessNormalVar() (src/ir/transforms/init_memref.cpp), when a Var has a non-ShapedType (e.g. ScalarType(INDEX) for dynamic-shape dimensions like M, N), the function still creates a new Var via make_shared<Var>(...) even though the type is unchanged. This new pointer is cached in var_map_, but every call to CloneTypeWithMemRefAndRemapExprs on a param/body type re-encounters the original Var and creates a fresh copy.

Fix

Return the original Var pointer for non-ShapedType variables in ProcessNormalVar(), since they don't need MemRef initialization:

} else {
    return var;  // Preserve pointer identity for scalar/index Vars
}
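
The effect of this branch can be sketched in isolation as follows. This is a simplified model under stated assumptions, not the code in init_memref.cpp: the Var struct, its is_shaped flag, and ProcessNormalVarSketch are hypothetical, the var_map_ cache is omitted, and the shaped branch only stands in for the real MemRef initialization.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for an IR variable; the real Var and type classes differ.
struct Var {
  std::string name;
  bool is_shaped;  // true for tensor-like Vars, false for scalar/index Vars
  Var(std::string n, bool shaped) : name(std::move(n)), is_shaped(shaped) {}
};
using VarPtr = std::shared_ptr<Var>;

// Hypothetical simplification of the fixed function: only shaped Vars are
// rebuilt, everything else is returned unchanged.
VarPtr ProcessNormalVarSketch(const VarPtr& var) {
  assert(var != nullptr);
  if (var->is_shaped) {
    // Stand-in for attaching a MemRef, which requires a new Var object.
    return std::make_shared<Var>(var->name, var->is_shaped);
  } else {
    return var;  // Preserve pointer identity for scalar/index Vars
  }
}
```

With this shape, repeated calls on the same index Var such as M always yield the same pointer, so every param type, return type, and body expression keeps referring to one shared object.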

Location

  • src/ir/transforms/init_memref.cpp:162-176, in ProcessNormalVar()
  • python/pypto/ir/instruments.py:100-107, suppression to be removed

Priority

Medium
