diff --git a/basics/transfer-sol/anchor/programs/transfer-sol/src/lib.rs b/basics/transfer-sol/anchor/programs/transfer-sol/src/lib.rs index a2fef1afc..52c37c067 100644 --- a/basics/transfer-sol/anchor/programs/transfer-sol/src/lib.rs +++ b/basics/transfer-sol/anchor/programs/transfer-sol/src/lib.rs @@ -27,12 +27,36 @@ pub mod transfer_sol { ctx: Context, amount: u64, ) -> Result<()> { - **ctx.accounts.payer.try_borrow_mut_lamports()? -= amount; - **ctx.accounts.recipient.try_borrow_mut_lamports()? += amount; + // Security invariants: + // - The source account must authorize the transfer (is_signer) + // - The source account must be owned by this program (direct lamports mutation) + // - Prevent under/overflow on lamport arithmetic + + let payer_lamports = ctx.accounts.payer.to_account_info().lamports(); + require!(payer_lamports >= amount, TransferSolError::InsufficientFunds); + + **ctx.accounts.payer.try_borrow_mut_lamports()? = payer_lamports + .checked_sub(amount) + .ok_or(TransferSolError::LamportArithmeticOverflow)?; + + let recipient_lamports = ctx.accounts.recipient.to_account_info().lamports(); + **ctx.accounts.recipient.try_borrow_mut_lamports()? = recipient_lamports + .checked_add(amount) + .ok_or(TransferSolError::LamportArithmeticOverflow)?; + Ok(()) } } +#[error_code] +pub enum TransferSolError { + #[msg("Insufficient funds in payer account")] + InsufficientFunds, + + #[msg("Lamport arithmetic overflow/underflow")] + LamportArithmeticOverflow, +} + #[derive(Accounts)] pub struct TransferSolWithCpi<'info> { #[account(mut)] @@ -44,12 +68,14 @@ pub struct TransferSolWithCpi<'info> { #[derive(Accounts)] pub struct TransferSolWithProgram<'info> { - /// CHECK: Use owner constraint to check account is owned by our program + // NOTE: This account must sign the transaction, otherwise *anyone* could drain lamports + // from program-owned accounts passed into this instruction. #[account( mut, owner = id() // value of declare_id!() )] - payer: UncheckedAccount<'info>, + payer: Signer<'info>, + #[account(mut)] recipient: SystemAccount<'info>, } diff --git a/basics/transfer-sol/anchor/tests/test.ts b/basics/transfer-sol/anchor/tests/test.ts index 05705380a..1f8eb80a1 100644 --- a/basics/transfer-sol/anchor/tests/test.ts +++ b/basics/transfer-sol/anchor/tests/test.ts @@ -58,6 +58,7 @@ describe("Anchor: Transfer SOL", () => { payer: payerAccount.publicKey, recipient: recipientAccount.publicKey, }) + .signers([payerAccount]) .rpc(); const recipientBalance = await provider.connection.getBalance(