diff --git a/src/Lean/Elab/PreDefinition/Basic.lean b/src/Lean/Elab/PreDefinition/Basic.lean index 812e56198b0a..f8c39a9bfe5d 100644 --- a/src/Lean/Elab/PreDefinition/Basic.lean +++ b/src/Lean/Elab/PreDefinition/Basic.lean @@ -11,6 +11,7 @@ public import Lean.Util.NumApps public import Lean.Meta.Eqns public import Lean.Elab.RecAppSyntax public import Lean.Elab.DefView +import Lean.Meta.InstMVarsAll public section @@ -48,7 +49,7 @@ Applies `Lean.instantiateMVars` to the types of values of each predefinition. -/ def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) := preDefs.mapM fun preDef => do - pure { preDef with type := (← instantiateMVars preDef.type), value := (← instantiateMVars preDef.value) } + pure { preDef with type := (← instantiateAllMVars preDef.type), value := (← instantiateAllMVars preDef.value) } /-- Applies `Lean.Elab.Term.levelMVarToParam` to the types of each predefinition. diff --git a/src/Lean/Meta/InstMVarsAll.lean b/src/Lean/Meta/InstMVarsAll.lean new file mode 100644 index 000000000000..02b8efb92668 --- /dev/null +++ b/src/Lean/Meta/InstMVarsAll.lean @@ -0,0 +1,50 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Joachim Breitner +-/ + +module + +prelude +public import Lean.Meta.Basic + +namespace Lean.Meta + +@[extern "lean_instantiate_expr_mvars_original"] +private opaque instantiateMVarsOriginalImp (mctx : MetavarContext) (e : Expr) : + MetavarContext × Expr + +@[extern "lean_instantiate_expr_mvars_all"] +private opaque instantiateAllMVarsImp (mctx : MetavarContext) (e : Expr) : + MetavarContext × Expr + +@[extern "lean_instantiate_expr_mvars_all_sharing"] +private opaque instantiateAllMVarsSharingImp (mctx : MetavarContext) (e : Expr) : + MetavarContext × Expr + +/-- The original single-pass `instantiateMVars` implementation, exposed for benchmarking + independently of which implementation is the default. -/ +public def instantiateMVarsOriginal (e : Expr) : MetaM Expr := do + if !e.hasMVar then return e + let (mctx, eNew) := instantiateMVarsOriginalImp (← getMCtx) e + modifyMCtx fun _ => mctx; return eNew + +/-- Like `instantiateMVars` but uses a fused two-pass approach. + Pass 1 resolves direct mvar assignments with write-back. + Pass 2 resolves delayed assignments with a fused fvar substitution, + avoiding separate `replace_fvars` calls. Preserves sharing using + a flat cache with lazy staleness detection via persistent scope + generation snapshots. -/ +public def instantiateAllMVars (e : Expr) : MetaM Expr := do + if !e.hasMVar then return e + let (mctx, eNew) := instantiateAllMVarsImp (← getMCtx) e + modifyMCtx fun _ => mctx; return eNew + +/-- Alias for `instantiateAllMVars`. -/ +public def instantiateAllMVarsSharing (e : Expr) : MetaM Expr := do + if !e.hasMVar then return e + let (mctx, eNew) := instantiateAllMVarsSharingImp (← getMCtx) e + modifyMCtx fun _ => mctx; return eNew + +end Lean.Meta diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 4ae7f52cf42c..2d33ec3d0992 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -400,6 +400,12 @@ def MetavarContext.getDelayedMVarAssignmentCore? 
(mctx : MetavarContext) (mvarId def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment := mctx.dAssignment.find? mvarId +@[export lean_delayed_mvar_assignment_fvars] +def DelayedMetavarAssignment.fvarsExp (d : DelayedMetavarAssignment) : Array Expr := d.fvars + +@[export lean_delayed_mvar_assignment_mvar_id_pending] +def DelayedMetavarAssignment.mvarIdPendingExp (d : DelayedMetavarAssignment) : MVarId := d.mvarIdPending + def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) := return (← getMCtx).getDelayedMVarAssignmentCore? mvarId @@ -1483,6 +1489,24 @@ def levelMVarToParam (mctx : MetavarContext) (alreadyUsedPred : Name → Bool) ( def getExprAssignmentDomain (mctx : MetavarContext) : Array MVarId := mctx.eAssignment.foldl (init := #[]) fun a mvarId _ => Array.push a mvarId +/-- +Abstract the given fvars from an unassigned mvar by creating a delayed-assigned mvar. +Returns `(mctx', result)` where `result` has the fvars replaced by `values`. +This can be used when encountering an unassigned mvar with an active fvar +substitution. +-/ +@[export lean_abstract_mvar_fvars] +def abstractMVarFVars (mctx : MetavarContext) (mvarId : MVarId) (fvars : Array Expr) (values : Array Expr) : MetavarContext × Expr := + match mctx.decls.find? 
mvarId with + | none => (mctx, mkMVar mvarId) + | some _mvarDecl => + let ngen : NameGenerator := { namePrefix := `_noUpdate, idx := mctx.mvarCounter } + let ctx : MkBinding.Context := { quotContext := `_root_, preserveOrder := false } + let state : MkBinding.State := { mctx, nextMacroScope := 0, ngen } + match (MkBinding.elimMVarDeps fvars (mkMVar mvarId) ctx).run state with + | .ok e s => (s.mctx, Expr.replaceFVars e fvars values) + | .error _ s => (s.mctx, mkMVar mvarId) + end MetavarContext namespace MVarId diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 913f16723535..d2dd7dc03940 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -19,4 +19,5 @@ add_library( inductive.cpp trace.cpp instantiate_mvars.cpp + instantiate_mvars_all.cpp ) diff --git a/src/kernel/instantiate_mvars.cpp b/src/kernel/instantiate_mvars.cpp index b49d74aa5cb5..d923561f0637 100644 --- a/src/kernel/instantiate_mvars.cpp +++ b/src/kernel/instantiate_mvars.cpp @@ -100,6 +100,8 @@ extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object extern "C" object * lean_get_mvar_assignment(obj_arg mctx, obj_arg mid); extern "C" object * lean_get_delayed_mvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_delayed_mvar_assignment_fvars(obj_arg d); +extern "C" object * lean_delayed_mvar_assignment_mvar_id_pending(obj_arg d); extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val); typedef object_ref delayed_assignment; @@ -116,6 +118,14 @@ option_ref get_delayed_mvar_assignment(metavar_ctx & mctx, n return option_ref(lean_get_delayed_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); } +array_ref delayed_assignment_fvars(delayed_assignment const & d) { + return array_ref(lean_delayed_mvar_assignment_fvars(d.to_obj_arg())); +} + +name delayed_assignment_mvar_id_pending(delayed_assignment const & d) { + return name(lean_delayed_mvar_assignment_mvar_id_pending(d.to_obj_arg())); +} + expr 
replace_fvars(expr const & e, array_ref const & fvars, expr const * rev_args) { size_t sz = fvars.size(); if (sz == 0) @@ -290,8 +300,8 @@ class instantiate_mvars_fn { metavariables, we replace the free variables `fvars` in `newVal` with the first `fvars.size` elements of `args`. */ - array_ref fvars(cnstr_get(d.get_val().raw(), 0), true); - name mid_pending(cnstr_get(d.get_val().raw(), 1), true); + array_ref fvars = delayed_assignment_fvars(d.get_val()); + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); if (fvars.size() > get_app_num_args(e)) { /* We don't have sufficient arguments for instantiating the free variables `fvars`. @@ -362,7 +372,8 @@ class instantiate_mvars_fn { expr operator()(expr const & e) { return visit(e); } }; -extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * e) { +/* The original single-pass implementation, exposed for benchmarking. */ +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_original(object * m, object * e) { metavar_ctx mctx(m); expr e_new = instantiate_mvars_fn(mctx)(expr(e)); object * r = alloc_cnstr(0, 2, 0); @@ -370,4 +381,10 @@ extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * cnstr_set(r, 1, e_new.steal()); return r; } + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all_sharing(object * m, object * e); + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * e) { + return lean_instantiate_expr_mvars_all_sharing(m, e); +} } diff --git a/src/kernel/instantiate_mvars_all.cpp b/src/kernel/instantiate_mvars_all.cpp new file mode 100644 index 000000000000..17a4bfdbfd5c --- /dev/null +++ b/src/kernel/instantiate_mvars_all.cpp @@ -0,0 +1,778 @@ +/* +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+ +Authors: Joachim Breitner +*/ +#include +#include +#include "util/name_set.h" +#include "util/name_hash_map.h" +#include "runtime/option_ref.h" +#include "runtime/array_ref.h" +#include "kernel/instantiate.h" +#include "kernel/expr.h" +#include "kernel/scope_cache.h" + +/* +This module provides an implementation of `instantiateMVars` with linear +complexity in the presence of nested delayed-assigned metavariables and +improved sharing. It proceeds in two passes. + +Terminology (for this file): + +* Direct MVar: an MVar that is not delayed-assigned. +* Pending MVar: the direct MVar stored in a `DelayedMetavarAssignment`. +* Assigned MVar: a direct MVar with an assignment, or a delayed-assigned MVar + with an assigned pending MVar. +* MVar DAG: the directed acyclic graph of MVars reachable from the expression. +* Resolvable MVar: an MVar where all MVars reachable from it (including itself) + are assigned. +* Updateable MVar: an assigned direct MVar, or a delayed-assigned MVar that is + resolvable but not reachable from any other resolvable delayed-assigned MVar. + +In the MVar DAG, the updateable delayed-assigned MVars form a cut with only +assigned MVars behind it and no resolvable delayed-assigned MVars before it. + +Pass 1 (`instantiate_direct_fn`): + Traverses all MVars and expressions reachable from the initial expression and + * instantiates all updateable direct MVars, updating their assignment with + its instantiation, + * instantiates all level MVars, + * determines if there are any updateable delayed-assigned MVars. + +Pass 2 (`instantiate_delayed_fn`): + Only run if there are updateable delayed-assigned MVars. Has an "outer" and + an "inner" mode, depending on whether it has crossed the updateable-MVar cut. + + In outer mode (empty substitution), all MVars are either unassigned direct + MVars (left alone), non-updateable delayed-assigned MVars (pending MVar + traversed in outer mode and updated with the result), or updateable + delayed-assigned MVars. 
+
+  When a delayed-assigned MVar is encountered, its MVar DAG is explored (via
+  `is_resolvable_pending`) to determine if it is resolvable (and thus
+  updateable). Results are cached across invocations.
+
+  If it is updateable, the substitution is initialized from its arguments and
+  traversal continues with the value of its pending MVar in inner mode.
+
+  In inner mode (non-empty substitution), all encountered delayed-assigned
+  MVars are, by construction, resolvable but not updateable. The substitution
+  is carried along and extended as we cross such MVars. Pending MVars of these
+  delayed-assigned MVars are NOT updated with the result (as the result is
+  valid only for this substitution, not in general).
+
+  Applying the substitution in one go, rather than instantiating each
+  delayed-assigned MVar on its own from inside out, avoids the quadratic
+  overhead of that approach when there are long chains of delayed-assigned
+  MVars.
+
+  A specially crafted caching data structure, the `scope_cache`, ensures that
+  sharing is preserved even across different delayed-assigned MVars (and hence
+  with different substitutions), when possible.
+*/ + +namespace lean { +extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val); +extern "C" object * lean_get_mvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_get_delayed_mvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_delayed_mvar_assignment_fvars(obj_arg d); +extern "C" object * lean_delayed_mvar_assignment_mvar_id_pending(obj_arg d); +extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val); +typedef object_ref metavar_ctx; +typedef object_ref delayed_assignment; + +static void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) { + object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg()); + mctx.set_box(r); +} + +static option_ref get_lmvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static void assign_mvar(metavar_ctx & mctx, name const & mid, expr const & e) { + object * r = lean_assign_mvar(mctx.steal(), mid.to_obj_arg(), e.to_obj_arg()); + mctx.set_box(r); +} + +static option_ref get_mvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static option_ref get_delayed_mvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_delayed_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static array_ref delayed_assignment_fvars(delayed_assignment const & d) { + return array_ref(lean_delayed_mvar_assignment_fvars(d.to_obj_arg())); +} + +static name delayed_assignment_mvar_id_pending(delayed_assignment const & d) { + return name(lean_delayed_mvar_assignment_mvar_id_pending(d.to_obj_arg())); +} + +/* Level metavariable instantiation. 
*/ +class instantiate_lmvars_all_fn { + metavar_ctx & m_mctx; + lean::unordered_map m_cache; + std::vector m_saved; + + inline level cache(level const & l, level r, bool shared) { + if (shared) { + m_cache.insert(mk_pair(l.raw(), r)); + } + return r; + } +public: + instantiate_lmvars_all_fn(metavar_ctx & mctx):m_mctx(mctx) {} + level visit(level const & l) { + if (!has_mvar(l)) + return l; + bool shared = false; + if (is_shared(l)) { + auto it = m_cache.find(l.raw()); + if (it != m_cache.end()) { + return it->second; + } + shared = true; + } + switch (l.kind()) { + case level_kind::Succ: + return cache(l, update_succ(l, visit(succ_of(l))), shared); + case level_kind::Max: case level_kind::IMax: + return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared); + case level_kind::Zero: case level_kind::Param: + lean_unreachable(); + case level_kind::MVar: { + option_ref r = get_lmvar_assignment(m_mctx, mvar_id(l)); + if (!r) { + return l; + } else { + level a(r.get_val()); + if (!has_mvar(a)) { + return a; + } else { + level a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_lmvar(m_mctx, mvar_id(l), a_new); + } + return a_new; + } + } + }} + } + level operator()(level const & l) { return visit(l); } +}; + +/* ============================================================================ + Pass 1: Instantiate updateable direct MVars with write-back. + For delayed-assigned MVars, pre-normalize the pending MVar's value + (resolving its direct MVar chains) but leave the delayed-assigned MVar + application in the expression. Also instantiates level MVars. + Unassigned MVars are left in place. + ============================================================================ */ + +class instantiate_direct_fn { + metavar_ctx & m_mctx; + instantiate_lmvars_all_fn m_level_fn; + name_set m_already_normalized; + /* Set to true when a delayed-assigned MVar with an assigned pending MVar + is encountered. 
Pass 2 is needed to resolve or write back such MVars. */ + bool m_has_updateable_delayed; + + lean::unordered_map m_cache; + std::vector m_saved; + + level visit_level(level const & l) { + return m_level_fn(l); + } + + levels visit_levels(levels const & ls) { + buffer lsNew; + for (auto const & l : ls) + lsNew.push_back(visit_level(l)); + return levels(lsNew); + } + + inline expr cache(expr const & e, expr r, bool shared) { + if (shared) { + m_cache.insert(mk_pair(e.raw(), r)); + } + return r; + } + + /* Get and normalize an updateable direct MVar's assignment. Write back the + normalized value. */ + optional get_assignment(name const & mid) { + option_ref r = get_mvar_assignment(m_mctx, mid); + if (!r) { + return optional(); + } + expr a(r.get_val()); + if (!has_mvar(a) || m_already_normalized.contains(mid)) { + return optional(a); + } + m_already_normalized.insert(mid); + expr a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_mvar(m_mctx, mid, a_new); + } + return optional(a_new); + } + + expr visit_app_default(expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + lean_assert(!is_mvar(*curr)); + expr f = visit(*curr); + return mk_rev_app(f, args.size(), args.data()); + } + + expr visit_mvar_app_args(expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + lean_assert(is_mvar(*curr)); + return mk_rev_app(*curr, args.size(), args.data()); + } + + expr visit_args_and_beta(expr const & f_new, expr const & e, buffer & args) { + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + bool preserve_data = false; + bool zeta = true; + return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta); + } + + expr visit_app(expr const & e) { + expr const & f = get_app_fn(e); 
+ if (!is_mvar(f)) { + return visit_app_default(e); + } + name const & mid = mvar_name(f); + /* Direct MVar assignment takes precedence. */ + if (auto f_new = get_assignment(mid)) { + buffer args; + return visit_args_and_beta(*f_new, e, args); + } + /* Check for delayed-assigned MVar. */ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (d) { + /* Pre-normalize the pending MVar's value so pass 2 finds it ready. + Only trigger pass 2 if the pending MVar is actually assigned; + unassigned pending MVars will clearly fail the resolvability check. */ + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (get_assignment(mid_pending)) + m_has_updateable_delayed = true; + } + return visit_mvar_app_args(e); + } + + expr visit_mvar(expr const & e) { + name const & mid = mvar_name(e); + if (auto r = get_assignment(mid)) { + return *r; + } + /* Not a direct MVar with assignment. Check if delayed-assigned. */ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (d) { + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (get_assignment(mid_pending)) + m_has_updateable_delayed = true; + } + return e; + } + +public: + instantiate_direct_fn(metavar_ctx & mctx) + : m_mctx(mctx), m_level_fn(mctx), m_has_updateable_delayed(false) {} + bool has_updateable_delayed() const { return m_has_updateable_delayed; } + + expr visit(expr const & e) { + if (!has_mvar(e)) + return e; + bool shared = false; + if (is_shared(e)) { + auto it = m_cache.find(e.raw()); + if (it != m_cache.end()) { + return it->second; + } + shared = true; + } + + switch (e.kind()) { + case expr_kind::BVar: + case expr_kind::Lit: case expr_kind::FVar: + lean_unreachable(); + case expr_kind::Sort: + return cache(e, update_sort(e, visit_level(sort_level(e))), shared); + case expr_kind::Const: + return cache(e, update_const(e, visit_levels(const_levels(e))), shared); + case expr_kind::MVar: + return visit_mvar(e); + case expr_kind::MData: + return cache(e, 
update_mdata(e, visit(mdata_expr(e))), shared); + case expr_kind::Proj: + return cache(e, update_proj(e, visit(proj_expr(e))), shared); + case expr_kind::App: + return cache(e, visit_app(e), shared); + case expr_kind::Pi: case expr_kind::Lambda: + return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared); + case expr_kind::Let: + return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared); + } + } + + expr operator()(expr const & e) { return visit(e); } +}; + +/* ============================================================================ + Pass 2: Resolve delayed-assigned MVars with fused fvar substitution. + Direct MVar chains have been pre-resolved by pass 1. + + Write-back behavior: + + Delayed-assigned MVars form a dependency tree: each delayed-assigned MVar's + pending MVar value may reference other delayed-assigned MVars. Some subtrees + of this tree are fully resolvable (all delayed-assigned MVars within are + resolvable), while others are not. + + Pass 2 fully resolves every maximal resolvable subtree. The roots of these + subtrees — updateable delayed-assigned MVars that are resolvable but whose + parent in the tree is not — form the updateable-MVar cut through the + dependency tree. Above the cut sit non-resolvable delayed-assigned MVars; + below the cut, everything is resolved. + + Pass 2 writes back the normalized pending MVar values of delayed-assigned + MVars above the cut (the non-resolvable ones whose children may have been + resolved). This is exactly the right set: these MVars are visited in outer + mode (empty fvar substitution), so their normalized values are suitable for + storing in the mctx. MVars below the cut are visited in inner mode + (non-empty substitution, fvars replaced by arguments), so their intermediate + values cannot be written back. 
+ ============================================================================ */ + +struct fvar_subst_entry { + unsigned depth; + unsigned scope; + expr value; +}; + +class instantiate_delayed_fn { + metavar_ctx & m_mctx; + name_hash_map m_fvar_subst; + unsigned m_depth; + + /* Scope-aware cache for (ptr, depth) → expr with lazy staleness detection. */ + struct key_hasher { + std::size_t operator()(std::pair const & p) const { + return hash((size_t)p.first >> 3, p.second); + } + }; + typedef std::pair cache_key; + scope_cache m_cache; + + /* After visit() returns, this holds the maximum fvar-substitution + scope that contributed to the result — i.e., the outermost scope at which the + result is valid and can be cached. Updated monotonically (via max) through + the save/reset/restore pattern in visit(). */ + unsigned m_result_scope; + + /* Write-back support: in outer mode, normalize and write back direct MVar + assignments. Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads + stored assignments and expects inner delayed-assigned MVars to be resolved. */ + name_set m_already_normalized; + std::vector m_saved; + + /* Resolvability caches — persistent across all delayed-assigned MVar + resolutions. A pending MVar is resolvable if its assigned value + (normalized by pass 1) would become MVar-free after resolution: all + remaining MVars must be delayed-assigned MVars in app position with + enough arguments, whose own pending MVars are also resolvable. */ + lean::unordered_map m_resolvable_expr_cache; + name_hash_map m_resolvable_pending_cache; /* 0 = in-progress, 1 = yes, 2 = no */ + + bool is_resolvable_pending(name const & pending) { + auto it = m_resolvable_pending_cache.find(pending); + if (it != m_resolvable_pending_cache.end()) + return it->second == 1; + /* Mark in-progress (cycle guard — shouldn't happen). 
*/ + m_resolvable_pending_cache[pending] = 0; + option_ref r = get_mvar_assignment(m_mctx, pending); + if (!r) { + m_resolvable_pending_cache[pending] = 2; + return false; + } + bool ok = is_resolvable_expr(expr(r.get_val())); + m_resolvable_pending_cache[pending] = ok ? 1 : 2; + return ok; + } + + bool is_resolvable_expr(expr const & e) { + if (!has_expr_mvar(e)) return true; + if (is_shared(e)) { + auto it = m_resolvable_expr_cache.find(e.raw()); + if (it != m_resolvable_expr_cache.end()) + return it->second; + } + bool r = is_resolvable_expr_core(e); + if (is_shared(e)) + m_resolvable_expr_cache[e.raw()] = r; + return r; + } + + bool is_resolvable_expr_core(expr const & e) { + switch (e.kind()) { + case expr_kind::MVar: + /* Bare MVar — direct MVar assignments were resolved by pass 1. Stuck. */ + return false; + case expr_kind::App: { + expr const & f = get_app_fn(e); + if (is_mvar(f)) { + name const & mid = mvar_name(f); + option_ref d = + get_delayed_mvar_assignment(m_mctx, mid); + if (!d) return false; + array_ref fvars = delayed_assignment_fvars(d.get_val()); + if (fvars.size() > get_app_num_args(e)) + return false; /* not enough args */ + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (!is_resolvable_pending(mid_pending)) + return false; + /* Check args too. 
*/ + expr const * curr = &e; + while (is_app(*curr)) { + if (!is_resolvable_expr(app_arg(*curr))) + return false; + curr = &app_fn(*curr); + } + return true; + } + return is_resolvable_expr(app_fn(e)) && is_resolvable_expr(app_arg(e)); + } + case expr_kind::Lambda: case expr_kind::Pi: + return is_resolvable_expr(binding_domain(e)) && + is_resolvable_expr(binding_body(e)); + case expr_kind::Let: + return is_resolvable_expr(let_type(e)) && + is_resolvable_expr(let_value(e)) && + is_resolvable_expr(let_body(e)); + case expr_kind::MData: + return is_resolvable_expr(mdata_expr(e)); + case expr_kind::Proj: + return is_resolvable_expr(proj_expr(e)); + default: + return true; + } + } + + /* Outer mode: no fvar substitution active; inner mode: inside a + resolvable delayed-assigned MVar with fvars mapped to arguments. */ + bool in_outer_mode() const { + return m_fvar_subst.empty(); + } + + optional lookup_fvar(name const & fid) { + auto it = m_fvar_subst.find(fid); + if (it == m_fvar_subst.end()) + return optional(); + m_result_scope = std::max(m_result_scope, it->second.scope); + unsigned d = m_depth - it->second.depth; + if (d == 0) + return optional(it->second.value); + return optional(lift_loose_bvars(it->second.value, d)); + } + + /* Get a direct MVar assignment. Visit it to resolve delayed-assigned + MVars and apply the fvar substitution. + In outer mode, normalize and write back the result to the mctx. + Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored + assignments and expects inner delayed-assigned MVars to be resolved. + In inner mode, no write-back: the result contains fvar-substituted + terms not suitable for the mctx. 
*/ + optional get_assignment(name const & mid) { + option_ref r = get_mvar_assignment(m_mctx, mid); + if (!r) + return optional(); + expr a(r.get_val()); + if (in_outer_mode()) { + if (m_already_normalized.contains(mid)) + return optional(a); + m_already_normalized.insert(mid); + expr a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_mvar(m_mctx, mid, a_new); + } + return optional(a_new); + } else { + return optional(visit(a)); + } + } + + expr visit_app_default(expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + lean_assert(!is_mvar(*curr)); + expr f = visit(*curr); + return mk_rev_app(f, args.size(), args.data()); + } + + expr visit_mvar_app_args(expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + lean_assert(is_mvar(*curr)); + return mk_rev_app(*curr, args.size(), args.data()); + } + + expr visit_delayed(array_ref const & fvars, name const & mid_pending, + expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + + size_t fvar_count = fvars.size(); + size_t extra_count = args.size() - fvar_count; + + /* Push a new scope and extend the fvar substitution. 
*/ + m_cache.push(); + struct saved_entry { name key; bool had_old; fvar_subst_entry old; }; + std::vector saved_entries; + saved_entries.reserve(fvar_count); + for (size_t i = 0; i < fvar_count; i++) { + name const & fid = fvar_name(fvars[i]); + auto old_it = m_fvar_subst.find(fid); + if (old_it != m_fvar_subst.end()) { + saved_entries.push_back({fid, true, old_it->second}); + } else { + saved_entries.push_back({fid, false, {0, 0, expr()}}); + } + m_fvar_subst[fid] = {m_depth, m_cache.scope(), args[args.size() - 1 - i]}; + } + + /* Get the pending MVar's value directly — it must be assigned (pass 1 + pre-normalized it). No write-back: we are in inner mode. */ + option_ref pending_val = get_mvar_assignment(m_mctx, mid_pending); + lean_assert(!!pending_val); + expr val_new = visit(expr(pending_val.get_val())); + + /* Pop scope; stale entries are detected by generation mismatch on lookup. */ + m_cache.pop(); + + /* Restore the fvar substitution. */ + for (auto & se : saved_entries) { + if (!se.had_old) { + m_fvar_subst.erase(se.key); + } else { + m_fvar_subst[se.key] = se.old; + } + } + + /* Use apply_beta instead of mk_rev_app: pass 1's beta-reduction may have + changed delayed-assigned MVar arguments (e.g., substituting a bvar with a + concrete value), so the resolved pending MVar value may be a lambda that + needs beta-reduction with the extra args. */ + bool preserve_data = false; + bool zeta = true; + return apply_beta(val_new, extra_count, args.data(), preserve_data, zeta); + } + + expr visit_app(expr const & e) { + expr const & f = get_app_fn(e); + if (!is_mvar(f)) { + return visit_app_default(e); + } + name const & mid = mvar_name(f); + /* Direct MVar assignments were resolved by pass 1. */ + lean_assert(!get_mvar_assignment(m_mctx, mid)); + /* Check for delayed-assigned MVar. 
*/ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (!d) { + return visit_mvar_app_args(e); + } + array_ref fvars = delayed_assignment_fvars(d.get_val()); + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (fvars.size() > get_app_num_args(e)) { + return visit_mvar_app_args(e); + } + if (is_resolvable_pending(mid_pending)) { + /* Updateable delayed-assigned MVar: cross the cut into inner mode. */ + return visit_delayed(fvars, mid_pending, e); + } else { + /* Non-resolvable delayed-assigned MVars only appear in outer mode: + inside a resolvable subtree all nested delayed-assigned MVars are + resolvable too. */ + lean_assert(in_outer_mode()); + /* Normalize the pending MVar's value for mctx write-back + (see write-back comment above). */ + (void)get_assignment(mid_pending); + return visit_mvar_app_args(e); + } + } + + expr visit_fvar(expr const & e) { + name const & fid = fvar_name(e); + if (auto r = lookup_fvar(fid)) { + return *r; + } + return e; + } + +public: + instantiate_delayed_fn(metavar_ctx & mctx) + : m_mctx(mctx), m_depth(0), m_result_scope(0) {} + + expr visit(expr const & e) { + if ((!has_fvar(e) || in_outer_mode()) && !has_expr_mvar(e)) + return e; + + bool shared = false; + if (is_shared(e)) { + if (auto r = m_cache.lookup(cache_key(e.raw(), m_depth), m_result_scope)) + return *r; + shared = true; + } + + /* Save and reset the result scope for this subtree. + After computing, cache_insert uses m_result_scope to place the entry + at the outermost valid scope level. Then we restore the parent's + watermark, taking the max with our contribution. 
*/ + unsigned saved_result_scope = m_result_scope; + m_result_scope = 0; + + expr r; + switch (e.kind()) { + case expr_kind::BVar: + case expr_kind::Lit: + lean_unreachable(); + case expr_kind::FVar: + r = visit_fvar(e); + goto done; /* skip caching for fvars */ + case expr_kind::Sort: + r = update_sort(e, visit_level(sort_level(e))); + break; + case expr_kind::Const: + r = update_const(e, visit_levels(const_levels(e))); + break; + case expr_kind::MVar: + /* Bare MVars in pass 2 are unassigned direct MVars: direct MVar + assignments were resolved by pass 1, and resolvable pending MVar + values contain no bare unassigned MVars. They only appear in + outer mode (at the top level or during write-back normalization). */ + lean_assert(in_outer_mode()); + lean_assert(!get_mvar_assignment(m_mctx, mvar_name(e))); + r = e; + goto done; + case expr_kind::MData: + r = update_mdata(e, visit(mdata_expr(e))); + break; + case expr_kind::Proj: + r = update_proj(e, visit(proj_expr(e))); + break; + case expr_kind::App: + r = visit_app(e); + break; + case expr_kind::Pi: case expr_kind::Lambda: { + expr d = visit(binding_domain(e)); + m_depth++; + expr b = visit(binding_body(e)); + m_depth--; + r = update_binding(e, d, b); + break; + } + case expr_kind::Let: { + expr t = visit(let_type(e)); + expr v = visit(let_value(e)); + m_depth++; + expr b = visit(let_body(e)); + m_depth--; + r = update_let(e, t, v, b); + break; + } + } + if (shared) { + r = m_cache.insert(cache_key(e.raw(), m_depth), r, m_result_scope); + } + + done: + m_result_scope = std::max(saved_result_scope, m_result_scope); + return r; + } + + level visit_level(level const & l) { + /* Pass 2 does not handle level MVars — pass 1 already resolved them. + But we still need this for the visit_levels call in update_sort/update_const. + Since levels have no fvars, we can just return them as-is. 
*/ + return l; + } + + levels visit_levels(levels const & ls) { + return ls; + } + + expr operator()(expr const & e) { return visit(e); } +}; + +/* ============================================================================ + Entry points: run pass 1 then pass 2. + ============================================================================ */ + +static object * run_instantiate_all(object * m, object * e) { + metavar_ctx mctx(m); + + /* Pass 1: instantiate updateable direct MVars, pre-normalize pending MVar values. */ + instantiate_direct_fn pass1(mctx); + expr e1 = pass1(expr(e)); + + /* Pass 2: resolve delayed-assigned MVars with fused fvar substitution. + Skip if pass 1 found no delayed-assigned MVars with assigned pending + MVars — none need resolution or write-back. */ + expr e2; + if (!pass1.has_updateable_delayed()) { + e2 = e1; + } else { + instantiate_delayed_fn pass2(mctx); + e2 = pass2(e1); + } + + /* (mctx, expr) */ + object * r = alloc_cnstr(0, 2, 0); + cnstr_set(r, 0, mctx.steal()); + cnstr_set(r, 1, e2.steal()); + return r; +} + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all(object * m, object * e) { + return run_instantiate_all(m, e); +} + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all_sharing(object * m, object * e) { + return run_instantiate_all(m, e); +} +} diff --git a/src/kernel/scope_cache.h b/src/kernel/scope_cache.h new file mode 100644 index 000000000000..3b0d6114bd95 --- /dev/null +++ b/src/kernel/scope_cache.h @@ -0,0 +1,176 @@ +/* +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Authors: Joachim Breitner +*/ +#pragma once +#include +#include +#include "util/alloc.h" +#include "runtime/optional.h" + +namespace lean { + +/* +Conceptually, the scope cache is a stack of `Key → (Value × Scope)` hashmaps. +The `Scope` is a counter indicating the lowest position in the stack for which +the entry is valid. 
+ +Its purpose is to provide caching for an operation that: + * maintains scopes (e.g. local contexts, substitutions). Higher stack + positions correspond to inner, more local scopes. + * For a given key, the result may depend on all or part of that scope. + * At lookup time, it is not known whether the value for a key will depend on + all or part of the scope, so only entries for the current innermost scope + are considered. + * At insert time, it is known which outermost scope the result depends on + (the "result scope"), and the result is valid for all scopes between that + and the innermost scope. + +The operations are: + * push(): push a new (empty) hashmap onto the stack. + * pop(): pop the top hashmap from the stack. + * scope(): current size of the stack (i.e. the index of the innermost scope). + * lookup(key, result_scope): look up in the top hashmap, returning the value + and propagating its result scope into `result_scope` via max. + * insert(key, value, result_scope): insert a key-value pair into the hashmaps + in the stack at depths in the range from `result_scope` to `scope()`. If it + encounters an existing value along the way, uses and returns that value for + improved sharing. + +The implementation inverts the data structure to a hashmap of stacks for +efficiency. It uses a generation counter to assign a unique identifier to each +scope, and maintains a persistent linked list of these to represent the current +scope stack. Cache entries are not touched upon pop(); instead they are lazily +cleaned up when accessed (the `rewind` operation). Upon insert, instead of +duplicating the entry for all valid scopes, it stores one entry with the range +of scopes it is valid for. 
+*/ +template> +class scope_cache { + struct scope_gen_node { + unsigned gen; + scope_gen_node * tail; /* parent scope, or nullptr for scope -1 */ + }; + + struct cache_entry { + Value result; + unsigned scope_level; /* scope at which this entry is (currently) valid */ + scope_gen_node * scope_gens; /* snapshot of scope_gens list at store time */ + unsigned result_scope; /* maximum scope that contributed to the result */ + }; + + typedef lean::unordered_map, Hash> cache_map; + + cache_map m_cache; + std::deque m_gen_arena; + scope_gen_node * m_scope_gens_list; + unsigned m_gen_counter; + unsigned m_scope; + +public: + scope_cache() : m_scope_gens_list(nullptr), m_gen_counter(0), m_scope(0) { + m_gen_arena.push_back({0, nullptr}); + m_scope_gens_list = &m_gen_arena.back(); + } + + unsigned scope() const { return m_scope; } + + /* Enter a new scope. Bumps the generation counter so that stale + entries at the new scope level are detected on lookup. */ + void push() { + m_scope++; + m_gen_counter++; + m_gen_arena.push_back({m_gen_counter, m_scope_gens_list}); + m_scope_gens_list = &m_gen_arena.back(); + } + + /* Leave the current scope. Follows the tail of the persistent + generation list back to the parent scope. */ + void pop() { + m_scope--; + m_scope_gens_list = m_scope_gens_list->tail; + } + +private: + /* Lazily clean up the top of a per-key entry stack: degrade entries + whose scopes were popped and evict entries that are stale due to + popped result scopes or scope re-entry. After rewind, either the + stack is empty or its top entry satisfies scope_level <= m_scope + with a matching scope generation. */ + void rewind(std::vector & stack) { + while (!stack.empty()) { + auto & top = stack.back(); + /* Discard entries whose result depends on popped scopes. */ + if (top.result_scope > m_scope) { + stack.pop_back(); + continue; + } + /* Degrade: follow tail pointers for scopes that were popped. 
*/ + while (top.scope_level > m_scope) { + top.scope_gens = top.scope_gens->tail; + top.scope_level--; + } + /* Check generation at scope_level. When scope_level < m_scope, + walk the current list down to scope_level first. */ + scope_gen_node * current_node = m_scope_gens_list; + for (unsigned i = m_scope; i > top.scope_level; i--) + current_node = current_node->tail; + if (top.scope_gens->gen == current_node->gen) return; + /* Generation mismatch: scope was re-entered. Walk both lists + in lockstep to find a valid level or exhaust to result_scope. */ + scope_gen_node * entry_node = top.scope_gens; + unsigned level = top.scope_level; + while (level > top.result_scope) { + entry_node = entry_node->tail; + current_node = current_node->tail; + level--; + if (entry_node->gen == current_node->gen) { + top.scope_level = level; + top.scope_gens = entry_node; + return; /* now scope_level < m_scope */ + } + } + /* No valid level found → discard. */ + stack.pop_back(); + } + } + +public: + /* Look up a cached result for the given key at the current scope. + On hit, updates `result_scope = max(result_scope, entry.result_scope)` + and returns the cached result. On miss, returns none. */ + optional lookup(Key const & key, unsigned & result_scope) { + auto it = m_cache.find(key); + if (it == m_cache.end()) return {}; + auto & stack = it->second; + rewind(stack); + if (stack.empty()) return {}; + auto & top = stack.back(); + if (top.scope_level != m_scope) return {}; + result_scope = std::max(result_scope, top.result_scope); + return optional(top.result); + } + + /* Insert a result for the given key at the current scope. + `result_scope` is the maximum scope that contributed to the result; + the entry is only valid when all scopes up to result_scope are unchanged. + If a valid entry with the same `result_scope` already exists, its value + is reused for sharing; the returned reference is the stored value. 
*/ + Value const & insert(Key const & key, Value const & result, unsigned result_scope) { + auto & stack = m_cache[key]; + rewind(stack); + Value shared = result; + if (!stack.empty() && stack.back().result_scope == result_scope) { + shared = stack.back().result; + } + while (!stack.empty() && stack.back().scope_level >= result_scope) { + stack.pop_back(); + } + stack.push_back({std::move(shared), m_scope, m_scope_gens_list, result_scope}); + return stack.back().result; + } +}; + +} diff --git a/tests/bench/delayed_assign.lean b/tests/bench/delayed_assign.lean index 7417272c8252..b9ec12a34ff0 100644 --- a/tests/bench/delayed_assign.lean +++ b/tests/bench/delayed_assign.lean @@ -1,4 +1,7 @@ import Lean +import Lean.Meta.InstMVarsAll + +set_option maxHeartbeats 4000000 open Lean Meta @@ -19,18 +22,6 @@ partial def solve (mvarId : MVarId) : MetaM Unit := do else let [] ← mvarId.applyConst ``True.intro | failure -partial def runBench (name : String) (n : Nat) (mk : Nat → MetaM MVarId) : MetaM Unit := do - let mvarId ← mk n - let startTime ← IO.monoNanosNow - solve mvarId - let endTime ← IO.monoNanosNow - let ms := (endTime - startTime).toFloat / 1000000.0 - let startTime ← IO.monoNanosNow - discard <| instantiateMVars (mkMVar mvarId) - let endTime ← IO.monoNanosNow - let instMs := (endTime - startTime).toFloat / 1000000.0 - IO.println s!"{name}_{n}: {ms} ms, instantiateMVars: {instMs} ms" - def mkBench1 (n : Nat) : MetaM MVarId := do let type := mkType n return (← mkFreshExprSyntheticOpaqueMVar type).mvarId! @@ -45,12 +36,33 @@ where | 0 => mkResultType n | i+1 => .forallE `x Nat.mkType (mkAnd (mkType i) (mkLE (n - i - 1))) .default +/-- Run a single implementation on a fresh copy of the benchmark, return (result, time_ms). 
-/ +def runImpl (n : Nat) (f : Expr → MetaM Expr) : MetaM (Expr × Float) := do + let mvarId ← mkBench1 n + solve mvarId + let t0 ← IO.monoNanosNow + let r ← f (mkMVar mvarId) + let t1 ← IO.monoNanosNow + let ms := (t1 - t0).toFloat / 1000000.0 + return (r, ms) + partial def bench1 (n : Nat) : MetaM Unit := do - runBench "bench1" n mkBench1 + let (rDefault, msDefault) ← runImpl n instantiateMVars + let (rOriginal, msOriginal) ← runImpl n instantiateMVarsOriginal + let (rAll, msAll) ← runImpl n instantiateAllMVars + + -- Verify correctness + unless Expr.eqv rDefault rOriginal do + IO.println s!"ERROR: instantiateMVars vs Original differ for n={n}" + unless Expr.eqv rDefault rAll do + IO.println s!"ERROR: instantiateMVars vs AllMVars differ for n={n}" + + IO.println s!"bench1_{n}: Default {msDefault} ms, Original {msOriginal} ms, AllMVars {msAll} ms" run_meta do IO.println "Example (n = 5):" let ex ← (← mkBench1 5).getType IO.println s!"{← ppExpr ex}" + IO.println "" for i in [10, 20, 40, 80, 100, 200, 300, 400, 500] do bench1 i diff --git a/tests/elab/1179b.lean.out.expected b/tests/elab/1179b.lean.out.expected index d9985ac3a81f..ba4c34d35dc4 100644 --- a/tests/elab/1179b.lean.out.expected +++ b/tests/elab/1179b.lean.out.expected @@ -2,15 +2,13 @@ def Foo.bar.match_1.{u_1} : {l₂ : Nat} → (motive : Foo l₂ → Sort u_1) → (t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive s₁.cons) → ((x : Foo l₂) → motive x) → motive t₂ := fun {l₂} motive t₂ h_1 h_2 => - (fun t₂_1 => - Foo.bar._sparseCasesOn_1 (motive := fun a x => l₂ = a → t₂ ≍ x → motive t₂) t₂_1 - (fun {l} t h => - Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂ ≍ t.cons → motive t₂) - (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t) - fun h h_3 => - Eq.ndrec (motive := fun a => (t₂_2 : Foo a) → Nat.hasNotBit 2 t₂_2.ctorIdx → t₂ ≍ t₂_2 → motive t₂) - (fun t₂_2 h h_4 => - Eq.ndrec (motive := fun t₂_3 => Nat.hasNotBit 2 t₂_3.ctorIdx → motive t₂) (fun h => h_2 t₂) (eq_of_heq h_4) - h) - h_3 t₂_1 h) - t₂ (Eq.refl l₂) (HEq.refl t₂) + 
Foo.bar._sparseCasesOn_1 (motive := fun a x => l₂ = a → t₂ ≍ x → motive t₂) t₂ + (fun {l} t h => + Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂ ≍ t.cons → motive t₂) (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h + t) + (fun h h_3 => + Eq.ndrec (motive := fun a => (t₂_1 : Foo a) → Nat.hasNotBit 2 t₂_1.ctorIdx → t₂ ≍ t₂_1 → motive t₂) + (fun t₂_1 h h_4 => + Eq.ndrec (motive := fun t₂_2 => Nat.hasNotBit 2 t₂_2.ctorIdx → motive t₂) (fun h => h_2 t₂) (eq_of_heq h_4) h) + h_3 t₂ h) + (Eq.refl l₂) (HEq.refl t₂) diff --git a/tests/elab/depElim1.lean.out.expected b/tests/elab/depElim1.lean.out.expected index 9a84846e5bfe..6104b68b92d8 100644 --- a/tests/elab/depElim1.lean.out.expected +++ b/tests/elab/depElim1.lean.out.expected @@ -24,29 +24,28 @@ def elimTest2.{u_1, u_2} : (α : Type u_1) → (x : α) → (xs : Vec α n) → (y : α) → (ys : Vec α n) → motive (n + 1) (Vec.cons x xs) (Vec.cons y ys)) → motive n xs ys := fun α motive n xs ys h_1 h_2 => - (fun xs_1 => - Vec.casesOn (motive := fun a x => n = a → xs ≍ x → motive n xs ys) xs_1 - (fun h => - Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.nil → motive n xs ys) - (fun xs ys h => - ⋯ ▸ - Vec.casesOn (motive := fun a x => 0 = a → ys ≍ x → motive 0 Vec.nil ys) ys (fun h h_3 => ⋯ ▸ h_1 ()) - (fun {n} a a_1 h => False.elim ⋯) ⋯ ⋯) - ⋯ xs ys) - fun {n_1} a a_1 h => - Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.cons a a_1 → motive n xs ys) - (fun xs ys h => - ⋯ ▸ - Vec.casesOn (motive := fun a_2 x => n_1 + 1 = a_2 → ys ≍ x → motive (n_1 + 1) (Vec.cons a a_1) ys) ys - (fun h => False.elim ⋯) - (fun {n} a_2 a_3 h => - n_1.elimOffset n 1 h fun x => - Eq.ndrec (motive := fun {n} => - (a_4 : Vec α n) → ys ≍ Vec.cons a_2 a_4 → motive (n_1 + 1) (Vec.cons a a_1) ys) - (fun a_4 h => ⋯ ▸ h_2 n_1 a a_1 a_2 a_4) x a_3) - ⋯ ⋯) - ⋯ xs ys) - xs ⋯ ⋯ + Vec.casesOn (motive := fun a x => n = a → xs ≍ x → motive n xs ys) xs + (fun h => + Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.nil → motive n xs ys) + (fun 
xs ys h => + ⋯ ▸ + Vec.casesOn (motive := fun a x => 0 = a → ys ≍ x → motive 0 Vec.nil ys) ys (fun h h_3 => ⋯ ▸ h_1 ()) + (fun {n} a a_1 h => False.elim ⋯) ⋯ ⋯) + ⋯ xs ys) + (fun {n_1} a a_1 h => + Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.cons a a_1 → motive n xs ys) + (fun xs ys h => + ⋯ ▸ + Vec.casesOn (motive := fun a_2 x => n_1 + 1 = a_2 → ys ≍ x → motive (n_1 + 1) (Vec.cons a a_1) ys) ys + (fun h => False.elim ⋯) + (fun {n} a_2 a_3 h => + n_1.elimOffset n 1 h fun x => + Eq.ndrec (motive := fun {n} => + (a_4 : Vec α n) → ys ≍ Vec.cons a_2 a_4 → motive (n_1 + 1) (Vec.cons a a_1) ys) + (fun a_4 h => ⋯ ▸ h_2 n_1 a a_1 a_2 a_4) x a_3) + ⋯ ⋯) + ⋯ xs ys) + ⋯ ⋯ elimTest3 : forall (α : Type.{u_1}) (β : Type.{u_2}) (motive : (List.{u_1} α) -> (List.{u_2} β) -> Sort.{u_3}) (x : List.{u_1} α) (y : List.{u_2} β), (Unit -> (motive (List.nil.{u_1} α) (List.nil.{u_2} β))) -> (forall (a : α) (b : β), motive (List.cons.{u_1} α a (List.nil.{u_1} α)) (List.cons.{u_2} β b (List.nil.{u_2} β))) -> (forall (a₁ : α) (a₂ : α) (as : List.{u_1} α) (b₁ : β) (b₂ : β) (bs : List.{u_2} β), motive (List.cons.{u_1} α a₁ (List.cons.{u_1} α a₂ as)) (List.cons.{u_2} β b₁ (List.cons.{u_2} β b₂ bs))) -> (forall (as : List.{u_1} α) (bs : List.{u_2} β), motive as bs) -> (motive x y) def elimTest3.{u_1, u_2, u_3} : (α : Type u_1) → (β : Type u_2) → diff --git a/tests/lean/run/instantiateAllMVarsCrossScope.lean b/tests/lean/run/instantiateAllMVarsCrossScope.lean new file mode 100644 index 000000000000..64245f6049af --- /dev/null +++ b/tests/lean/run/instantiateAllMVarsCrossScope.lean @@ -0,0 +1,99 @@ +import Lean +import Lean.Meta.InstMVarsAll + +open Lean Meta + +/- +Test: cross-scope sharing in scope_cache insert. + +A shared expression `succ_x := Nat.succ x_fvar` is visited at scope 1 +(as d2's argument, before scope 2 is pushed) and then at scope 2 +(inside d2's pending value). 
The insert optimization should reuse the +scope-1 result when inserting at scope 2, since result_scope=1 and +scope 1 hasn't changed. + + ?root := fun (a : Nat) => ?d1 a + ?d1 delayed [x] := ?body + ?body := ?d2 succ_x ← succ_x visited at scope 1 as d2's arg + ?d2 delayed [z] := ?inner + ?inner := Prod.mk z succ_x ← z = R1, succ_x visited at scope 2 + +The ordering guarantee comes from visit_delayed's control flow: args +are visited before pushing the new scope, the pending value is visited +after. This does not depend on the order in which application arguments +are traversed. + +Expected result: fun (a : Nat) => (Nat.succ a, Nat.succ a) + +With insert sharing, both Nat.succ a subexpressions in the result are +the same object (ptrEq). Without it, they are structurally equal but +distinct objects. +-/ + +private def mkCrossScopeTest : MetaM Expr := do + let nat := mkConst ``Nat + withLocalDeclD `x nat fun x_fvar => + withLocalDeclD `z nat fun z_fvar => do + let succ_x := mkApp (mkConst ``Nat.succ) x_fvar + + -- ?inner := Prod.mk z succ_x + let pairTy := mkApp2 (mkConst ``Prod [.succ .zero, .succ .zero]) nat nat + let inner ← mkFreshExprMVar pairTy + inner.mvarId!.assign + (mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat z_fvar succ_x) + + -- ?d2 delayed [z] := ?inner, takes one Nat arg + let d2_ty ← mkArrow nat pairTy + let d2 ← mkFreshExprMVar d2_ty (kind := .syntheticOpaque) + assignDelayedMVar d2.mvarId! #[z_fvar] inner.mvarId! + + -- ?body := ?d2 succ_x + let body ← mkFreshExprMVar pairTy + body.mvarId!.assign (mkApp d2 succ_x) + + -- ?d1 delayed [x] := ?body + let d1_ty ← mkArrow nat pairTy + let d1 ← mkFreshExprMVar d1_ty (kind := .syntheticOpaque) + assignDelayedMVar d1.mvarId! #[x_fvar] body.mvarId! 
+ + -- ?root := fun (a : Nat) => ?d1 a + let rootTy ← mkArrow nat pairTy + let root ← mkFreshExprMVar rootTy + root.mvarId!.assign (Lean.mkLambda `a .default nat (mkApp d1 (.bvar 0))) + return root + +-- Expected: fun (a : Nat) => (Nat.succ a, Nat.succ a) +private def mkExpected : Expr := + let nat := mkConst ``Nat + let succ_a := mkApp (mkConst ``Nat.succ) (.bvar 0) + let body := mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat succ_a succ_a + Lean.mkLambda `a .default nat body + +-- Extract the two components from the result +-- Result shape: fun (a : Nat) => @Prod.mk Nat Nat fst snd +private def extractComponents (e : Expr) : Expr × Expr := + let body := e.bindingBody! + let snd := body.appArg! + let fst := body.appFn!.appArg! + (fst, snd) + +unsafe def checkImpl (label : String) (f : Expr → MetaM Expr) : MetaM Bool := do + let root ← mkCrossScopeTest + let expected := mkExpected + let saved ← saveState + let result ← f root + saved.restore + unless result == expected do + throwError "{label}: wrong result, got {result}" + let (fst, snd) := extractComponents result + let shared := ptrEq fst snd + IO.println s!"{label}: cross-scope sharing = {shared}" + return shared + +run_meta do + let _ ← unsafe checkImpl "instantiateMVarsOriginal" instantiateMVarsOriginal + let sharingShared ← unsafe checkImpl "instantiateAllMVarsSharing" instantiateAllMVarsSharing + let defaultShared ← unsafe checkImpl "instantiateMVars" instantiateMVars + -- instantiateAllMVarsSharing (= instantiateMVars) should have cross-scope sharing + guard sharingShared + guard defaultShared diff --git a/tests/lean/run/instantiateAllMVarsShadow.lean b/tests/lean/run/instantiateAllMVarsShadow.lean new file mode 100644 index 000000000000..575b4f9b6f0b --- /dev/null +++ b/tests/lean/run/instantiateAllMVarsShadow.lean @@ -0,0 +1,176 @@ +import Lean +import Lean.Meta.InstMVarsAll + +open Lean Meta + +/- +Test: fvar shadowing in nested delayed mvars. 
+ +Two delayed mvars bind the same fvar `x_fvar` to different values. +A shared subexpression `succ_x := Nat.succ x_fvar` appears in both scopes. + + ?root := fun (a : Nat) => ?d1_aux #0 + ?d1_aux delayed [x_fvar] := ?body + ?body := Prod.mk succ_x (?d2_aux succ_x) ← succ_x is shared + ?d2_aux delayed [x_fvar] := ?inner + ?inner := succ_x ← same shared object + +Expected result: + fun (a : Nat) => (Nat.succ a, Nat.succ (Nat.succ a)) + +When resolving ?d1_aux with arg `a`: + - succ_x with x_fvar → a gives Nat.succ a (first component) + - ?d2_aux gets arg (Nat.succ a), so x_fvar → Nat.succ a + succ_x with x_fvar → Nat.succ a gives Nat.succ (Nat.succ a) (second component) + +A buggy cache could return the cached scope-1 result (Nat.succ a) for the scope-2 +visit, producing (Nat.succ a, Nat.succ a) instead. +-/ + +private def mkShadowTest : MetaM Expr := do + let nat := mkConst ``Nat + withLocalDeclD `x nat fun x_fvar => do + -- shared object referencing x_fvar + let succ_x := mkApp (mkConst ``Nat.succ) x_fvar + + -- ?inner := succ_x + let inner ← mkFreshExprMVar nat + inner.mvarId!.assign succ_x + + -- ?d2_aux delayed [x_fvar] := ?inner + let d2_ty ← mkArrow nat nat + let d2_aux ← mkFreshExprMVar d2_ty (kind := .syntheticOpaque) + assignDelayedMVar d2_aux.mvarId! #[x_fvar] inner.mvarId! + + -- ?body := ⟨succ_x, ?d2_aux succ_x⟩ + let pairTy := mkApp2 (mkConst ``Prod [.succ .zero, .succ .zero]) nat nat + let body ← mkFreshExprMVar pairTy + body.mvarId!.assign + (mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat + succ_x (mkApp d2_aux succ_x)) + + -- ?d1_aux delayed [x_fvar] := ?body + let d1_ty ← mkArrow nat pairTy + let d1_aux ← mkFreshExprMVar d1_ty (kind := .syntheticOpaque) + assignDelayedMVar d1_aux.mvarId! #[x_fvar] body.mvarId! 
+ + -- ?root := fun (a : Nat) => ?d1_aux a + let rootTy ← mkArrow nat pairTy + let root ← mkFreshExprMVar rootTy + root.mvarId!.assign (Lean.mkLambda `a .default nat (mkApp d1_aux (.bvar 0))) + return root + +-- Expected: fun (a : Nat) => (Nat.succ a, Nat.succ (Nat.succ a)) +private def mkExpected : Expr := + let nat := mkConst ``Nat + let succ := mkConst ``Nat.succ + -- #0 refers to the lambda-bound `a` + let succ_a := mkApp succ (.bvar 0) + let succ_succ_a := mkApp succ succ_a + let body := mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat succ_a succ_succ_a + Lean.mkLambda `a .default nat body + +private def check (label : String) (result : Expr) : MetaM Unit := do + let expected := mkExpected + unless result == expected do + throwError "{label}: expected {expected}, got {result}" + +-- Both implementations must produce the expected result +run_meta do + let root ← mkShadowTest + + let saved ← saveState + let eOrig ← instantiateMVarsOriginal root + saved.restore + check "instantiateMVarsOriginal (shadow)" eOrig + + let saved ← saveState + let eNew ← instantiateAllMVars root + saved.restore + check "instantiateAllMVars (shadow)" eNew + +/- +Test: an fvar first seen unsubstituted, then substituted at a higher scope. + +A shared subexpression `succ_y := Nat.succ y_fvar` is used both: + - directly in the body of d1 (where y is NOT bound), and + - inside d2's pending value (where y IS bound). + + ?root := fun (a : Nat) => ?d1 a + ?d1 delayed [x] := ?body + ?body := Prod.mk succ_y (?d2 succ_y) ← succ_y shared + ?d2 delayed [y] := ?inner ← y is NOW bound + ?inner := succ_y ← same shared object + +Expected result: + fun (a : Nat) => (Nat.succ y_fvar, Nat.succ (Nat.succ y_fvar)) + +At scope 1 (d1), x → a. Visit body: + - succ_y: y is NOT in fvar_subst. Result is succ_y unchanged. + - ?d2 succ_y: arg succ_y visited → succ_y. Then d2 at scope 2 with y → succ_y. + - Visit ?inner = succ_y. y IS in fvar_subst → Nat.succ succ_y = Nat.succ (Nat.succ y_fvar). 
+ +A buggy cache would return the scope-1 result (succ_y unchanged) at scope 2, +producing (Nat.succ y_fvar, Nat.succ y_fvar) instead. +-/ + +private def mkLateBindTest : MetaM (Expr × Expr) := do + let nat := mkConst ``Nat + withLocalDeclD `x nat fun x_fvar => + withLocalDeclD `y nat fun y_fvar => do + -- shared object referencing y_fvar (NOT x_fvar) + let succ_y := mkApp (mkConst ``Nat.succ) y_fvar + + -- ?inner := succ_y + let inner ← mkFreshExprMVar nat + inner.mvarId!.assign succ_y + + -- ?d2 delayed [y_fvar] := ?inner + let d2_ty ← mkArrow nat nat + let d2 ← mkFreshExprMVar d2_ty (kind := .syntheticOpaque) + assignDelayedMVar d2.mvarId! #[y_fvar] inner.mvarId! + + -- ?body := ⟨succ_y, ?d2 succ_y⟩ + let pairTy := mkApp2 (mkConst ``Prod [.succ .zero, .succ .zero]) nat nat + let body ← mkFreshExprMVar pairTy + body.mvarId!.assign + (mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat + succ_y (mkApp d2 succ_y)) + + -- ?d1 delayed [x_fvar] := ?body + let d1_ty ← mkArrow nat pairTy + let d1 ← mkFreshExprMVar d1_ty (kind := .syntheticOpaque) + assignDelayedMVar d1.mvarId! #[x_fvar] body.mvarId! 
+ + -- ?root := fun (a : Nat) => ?d1 a + let rootTy ← mkArrow nat pairTy + let root ← mkFreshExprMVar rootTy + root.mvarId!.assign (Lean.mkLambda `a .default nat (mkApp d1 (.bvar 0))) + return (root, y_fvar) + +-- Expected: fun (a : Nat) => (Nat.succ y_fvar, Nat.succ (Nat.succ y_fvar)) +private def mkExpectedLateBind (y_fvar : Expr) : Expr := + let nat := mkConst ``Nat + let succ := mkConst ``Nat.succ + let succ_y := mkApp succ y_fvar + let succ_succ_y := mkApp succ succ_y + let body := mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat succ_y succ_succ_y + Lean.mkLambda `a .default nat body + +private def checkLateBind (label : String) (result : Expr) (y_fvar : Expr) : MetaM Unit := do + let expected := mkExpectedLateBind y_fvar + unless result == expected do + throwError "{label}: expected {expected}, got {result}" + +run_meta do + let (root, y_fvar) ← mkLateBindTest + + let saved ← saveState + let eOrig ← instantiateMVarsOriginal root + saved.restore + checkLateBind "instantiateMVarsOriginal (late-bind)" eOrig y_fvar + + let saved ← saveState + let eNew ← instantiateAllMVars root + saved.restore + checkLateBind "instantiateAllMVars (late-bind)" eNew y_fvar diff --git a/tests/lean/run/instantiateAllMVarsSharing.lean b/tests/lean/run/instantiateAllMVarsSharing.lean new file mode 100644 index 000000000000..85cadee3dbae --- /dev/null +++ b/tests/lean/run/instantiateAllMVarsSharing.lean @@ -0,0 +1,92 @@ +import Lean +import Lean.Meta.InstMVarsAll + +open Lean Meta + +/- +Minimal test for sharing in `instantiateAllMVars` and `instantiateAllMVarsSharing`. 
+ +We construct the metavariable assignment graph for the goal +`∀ s, (s = s → (s = s) ∧ (s = s)) ∧ (s = s)`: + + ?root := fun (s : Nat) => ?rootAux #0 + ?rootAux delayed [s_fvar] := ?body + ?body := @And.intro leftTy rightTy ?left right + ?left := fun (h : eq_ss) => ?leftAux #0 + ?leftAux delayed [h_fvar] := ?inner + ?inner := @And.intro eq_ss eq_ss h_fvar h_fvar + +where + + eq_ss := @Eq Nat s_fvar s_fvar ← single shared object + andTy := And eq_ss eq_ss ← contains eq_ss + leftTy := eq_ss → andTy ← forallE body contains eq_ss + rightTy := eq_ss + right := @Eq.refl Nat s_fvar +-/ + +private def mkTestRoot : MetaM Expr := do + let nat := mkConst ``Nat + withLocalDeclD `s nat fun s_fvar => do + let eq_ss ← mkEq s_fvar s_fvar -- shared object + + let andTy := mkApp2 (mkConst ``And) eq_ss eq_ss -- (s=s) ∧ (s=s) + let leftTy ← mkArrow eq_ss andTy -- s=s → (s=s) ∧ (s=s) + let rightTy := eq_ss -- s=s + let bodyTy := mkApp2 (mkConst ``And) leftTy rightTy + + let body ← mkFreshExprMVar bodyTy + let left ← mkFreshExprMVar leftTy + + withLocalDeclD `h eq_ss fun h_fvar => do + -- ?inner : (s=s) ∧ (s=s), proved by And.intro eq_ss eq_ss h h + let inner ← mkFreshExprMVar andTy + let leftDecl ← left.mvarId!.getDecl + let leftAux ← mkFreshExprMVarAt leftDecl.lctx leftDecl.localInstances + leftDecl.type .syntheticOpaque + assignDelayedMVar leftAux.mvarId! #[h_fvar] inner.mvarId! + left.mvarId!.assign (Lean.mkLambda `h .default eq_ss (mkApp leftAux (.bvar 0))) + inner.mvarId!.assign (mkApp4 (mkConst ``And.intro) eq_ss eq_ss h_fvar h_fvar) + + let right := mkApp2 (mkConst ``Eq.refl [1]) nat s_fvar + body.mvarId!.assign (mkApp4 (mkConst ``And.intro) leftTy rightTy left right) + + let rootTy ← mkForallFVars #[s_fvar] bodyTy + let root ← mkFreshExprMVar rootTy + let rootDecl ← root.mvarId!.getDecl + let rootAux ← mkFreshExprMVarAt rootDecl.lctx rootDecl.localInstances + rootDecl.type .syntheticOpaque + assignDelayedMVar rootAux.mvarId! #[s_fvar] body.mvarId! 
+ root.mvarId!.assign (Lean.mkLambda `s .default nat (mkApp rootAux (.bvar 0))) + return root + +-- Both variants now produce the same result with the same sharing +run_meta do + let root ← mkTestRoot + + let saved ← saveState + let eSharing ← instantiateAllMVarsSharing root + let nSharing ← eSharing.numObjs + saved.restore + + let saved ← saveState + let eAll ← instantiateAllMVars root + let nAll ← eAll.numObjs + saved.restore + + guard (eSharing == eAll) + guard (nAll == nSharing) + +-- instantiateAllMVarsSharing produces the same result as instantiateMVars +run_meta do + let root ← mkTestRoot + + let saved ← saveState + let eStd ← instantiateMVars root + saved.restore + + let saved ← saveState + let eSharing ← instantiateAllMVarsSharing root + saved.restore + + guard (eStd == eSharing)