diff --git a/contracts/vault/src/lib.rs b/contracts/vault/src/lib.rs index e84f9700..e5bcc30b 100644 --- a/contracts/vault/src/lib.rs +++ b/contracts/vault/src/lib.rs @@ -149,6 +149,12 @@ pub enum DataKey { PriceOracle, OracleEnabled, OracleHeartbeat, + // Snapshot / checkpointing + CheckpointNonce, + CheckpointLedger(u32), + CheckpointTotalAssets(u32), + CheckpointTotalShares(u32), + CheckpointBalance(u32, Address), } #[contracttype] @@ -1354,6 +1360,84 @@ impl YieldVault { } } } + + /// Create a new checkpoint snapshot of total assets and total shares. + /// Only the Admin may call this. Returns the new checkpoint id. + pub fn create_checkpoint(env: Env) -> u32 { + let admin: Address = get_admin(&env).expect("Admin not set"); + admin.require_auth(); + + let mut next_nonce: u32 = env + .storage() + .instance() + .get(&DataKey::CheckpointNonce) + .unwrap_or(0u32); + next_nonce = next_nonce.checked_add(1).expect("overflow"); + + // Record ledger sequence for provenance + let ledger_seq = env.ledger().sequence(); + + env.storage() + .instance() + .set(&DataKey::CheckpointNonce, &next_nonce); + env.storage() + .instance() + .set(&DataKey::CheckpointLedger(next_nonce), &ledger_seq); + + // Snapshot global totals + let total_assets = Self::total_assets(env.clone()); + let total_shares = Self::total_shares(env.clone()); + env.storage() + .instance() + .set(&DataKey::CheckpointTotalAssets(next_nonce), &total_assets); + env.storage() + .instance() + .set(&DataKey::CheckpointTotalShares(next_nonce), &total_shares); + + next_nonce + } + + /// Returns the total shares recorded at a given checkpoint id. + pub fn total_shares_at(env: Env, checkpoint_id: u32) -> i128 { + env.storage() + .instance() + .get(&DataKey::CheckpointTotalShares(checkpoint_id)) + .unwrap_or(0i128) + } + + /// Returns the total assets recorded at a given checkpoint id. + pub fn total_assets_at(env: Env, checkpoint_id: u32) -> i128 { + env.storage() + .instance() + .get(&DataKey::CheckpointTotalAssets(checkpoint_id)) + .unwrap_or(0i128) + } + + /// Snapshot the caller's share balance for the latest checkpoint. + /// The caller must `require_auth` as the `user` parameter. + pub fn snapshot_user_balance(env: Env, user: Address) { + user.require_auth(); + let nonce: u32 = env + .storage() + .instance() + .get(&DataKey::CheckpointNonce) + .unwrap_or(0u32); + if nonce == 0 { + panic!("no checkpoint exists"); + } + let bal = Self::balance(env.clone(), user.clone()); + env.storage() + .instance() + .set(&DataKey::CheckpointBalance(nonce, user), &bal); + } + + /// Returns the user's snapshot balance at a given checkpoint id (0 if not recorded). + pub fn balance_at(env: Env, user: Address, checkpoint_id: u32) -> i128 { + env.storage() + .instance() + .get(&DataKey::CheckpointBalance(checkpoint_id, user)) + .unwrap_or(0i128) + } /// Read-only: returns contract metadata such as version and simple config flags. pub fn metadata(env: Env) -> ContractMetadata { let state = Self::get_state(&env); diff --git a/contracts/vault/src/test.rs b/contracts/vault/src/test.rs index 1e8fcd05..0ba58b53 100644 --- a/contracts/vault/src/test.rs +++ b/contracts/vault/src/test.rs @@ -406,6 +406,31 @@ fn test_accrue_yield_increases_total_assets() { assert_eq!(vault.total_shares(), 0); // shares unchanged. } +#[test] +fn test_checkpoint() { + let env = Env::default(); + env.mock_all_auths(); + + let (vault, usdc, usdc_sa, _admin) = setup_vault(&env); + let user = Address::generate(&env); + usdc_sa.mint(&user, &100); + + // User deposits 100 + vault.deposit(&user, &100); + + // Create a checkpoint (admin-auth in production; tests mock auth) + let cp = vault.create_checkpoint(); + assert_eq!(cp, 1); + + // Global totals should be recorded + assert_eq!(vault.total_shares_at(&cp), 100); + assert_eq!(vault.total_assets_at(&cp), 100); + + // User snapshots their balance for the checkpoint + vault.snapshot_user_balance(&user); + assert_eq!(vault.balance_at(&user, &cp), 100); +} + // ─── 5. report_benji_yield ─────────────────────────────────────────────────── #[test]