Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/controller/BaseController.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ import { IGenericShare } from "../interfaces/IGenericShare.sol";
* - Permanent contract malfunction
*/
abstract contract BaseController is AccessControlUpgradeable, ReentrancyGuardTransientUpgradeable {
// ========================================
// TRANSIENT
// ========================================

/**
* @notice Flag of deposit action in the current transaction
*/
bool internal transient _rebalanceGuardDeposited;

/**
* @notice Flag of withdraw action in the current transaction
*/
bool internal transient _rebalanceGuardWithdrawn;

// ========================================
// PERSISTENT
// ========================================

/**
* @notice Maximum basis points value representing 100%
*/
Expand Down
35 changes: 35 additions & 0 deletions src/controller/Controller.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ contract Controller is
* @notice Thrown when caller is not the main vault for an asset
*/
error Controller_CallerNotMainVault();
/**
* @notice Thrown when a deposit/mint and withdrawal/redeem are attempted in the same transaction
*/
error Controller_DepositAndWithdrawInSameTx();

/**
* @notice Ensures only registered vaults can call the function
Expand All @@ -96,6 +100,33 @@ contract Controller is
require(_vaultFor[_vaultAsset(msg.sender)] == msg.sender, Controller_CallerNotMainVault());
}

/**
* @notice Enum representing actions that trigger the rebalance guard
* @dev Used to prevent deposit/mint and withdrawal/redeem operations in the same transaction
*/
enum RebalanceGuardAction {
Deposit,
Withdraw
}

/**
* @notice Ensures deposit/mint and withdrawal/redeem operations cannot occur in the same transaction
*/
modifier rebalanceGuard(RebalanceGuardAction action) {
_rebalanceGuard(action);
_;
}

function _rebalanceGuard(RebalanceGuardAction action) internal {
if (action == RebalanceGuardAction.Deposit) {
require(!_rebalanceGuardWithdrawn, Controller_DepositAndWithdrawInSameTx());
_rebalanceGuardDeposited = true;
} else if (action == RebalanceGuardAction.Withdraw) {
require(!_rebalanceGuardDeposited, Controller_DepositAndWithdrawInSameTx());
_rebalanceGuardWithdrawn = true;
}
}

/**
* @notice Constructor that disables initializers to prevent direct initialization
* @dev Uses OpenZeppelin's initializer pattern for upgradeable contracts
Expand Down Expand Up @@ -322,6 +353,7 @@ contract Controller is
public
onlyMainVault
notPaused
rebalanceGuard(RebalanceGuardAction.Deposit)
returns (uint256 shares)
{
address vault = msg.sender;
Expand Down Expand Up @@ -352,6 +384,7 @@ contract Controller is
public
onlyMainVault
notPaused
rebalanceGuard(RebalanceGuardAction.Deposit)
returns (uint256 normalizedAssets)
{
address vault = msg.sender;
Expand Down Expand Up @@ -384,6 +417,7 @@ contract Controller is
public
onlyVault
notPaused
rebalanceGuard(RebalanceGuardAction.Withdraw)
returns (uint256 shares)
{
address vault = msg.sender;
Expand Down Expand Up @@ -418,6 +452,7 @@ contract Controller is
public
onlyVault
notPaused
rebalanceGuard(RebalanceGuardAction.Withdraw)
returns (uint256 normalizedAssets)
{
address vault = msg.sender;
Expand Down
16 changes: 16 additions & 0 deletions tests/harness/ControllerHarness.sol
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ contract ControllerHarness is Controller {
return _totalAssetsDeltaToHitProportionality(proportionalityLimit, vaultAssets, totalAssets);
}

function exposed_rebalanceGuardDeposited() external view returns (bool) {
return _rebalanceGuardDeposited;
}

function exposed_rebalanceGuardWithdrawn() external view returns (bool) {
return _rebalanceGuardWithdrawn;
}

// ========================================
// WORKAROUND FUNCTIONS
// ========================================
Expand Down Expand Up @@ -235,4 +243,12 @@ contract ControllerHarness is Controller {
assert(maxSlippage <= MAX_BPS);
maxProtocolRebalanceSlippage = uint16(maxSlippage);
}

function workaround_setRebalanceGuardDeposited(bool deposited) public {
_rebalanceGuardDeposited = deposited;
}

function workaround_setRebalanceGuardWithdrawn(bool withdrawn) public {
_rebalanceGuardWithdrawn = withdrawn;
}
}
20 changes: 20 additions & 0 deletions tests/helper/Multicall.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.29;

contract Multicall {
struct Call {
address target;
bytes callData;
}

function aggregate(Call[] calldata calls) public {
for (uint256 i; i < calls.length; ++i) {
(bool success, bytes memory data) = calls[i].target.call(calls[i].callData);
if (!success) {
assembly {
revert(add(data, 0x20), mload(data))
}
}
}
}
}
49 changes: 49 additions & 0 deletions tests/integration/Controller.integration.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { MockERC20 } from "../helper/MockERC20.sol";
import { MockPriceFeed } from "../helper/MockPriceFeed.sol";
import { MockStrategy } from "../helper/MockStrategy.sol";
import { MockSwapper } from "../helper/MockSwapper.sol";
import { Multicall } from "../helper/Multicall.sol";

abstract contract ControllerIntegrationTest is Test {
Controller controller;
Expand Down Expand Up @@ -81,6 +82,7 @@ abstract contract ControllerIntegrationTest is Test {
}

contract Controller_DepositWithdraw_IntegrationTest is ControllerIntegrationTest {
/// forge-config: default.isolate = true
function test_depositWithdraw_whenPriceStable() public {
// Deposit
vm.startPrank(user);
Expand Down Expand Up @@ -131,6 +133,7 @@ contract Controller_DepositWithdraw_IntegrationTest is ControllerIntegrationTest
assertEq(asset2.balanceOf(address(strategy2)), 50e8);
}

/// forge-config: default.isolate = true
function test_depositWithdraw_whenPriceVolatile() public {
priceFeed1.setPrice(1.1e8);
priceFeed2.setPrice(0.9e8);
Expand Down Expand Up @@ -204,6 +207,7 @@ contract Controller_DepositWithdraw_IntegrationTest is ControllerIntegrationTest
}

contract Controller_MintRedeem_IntegrationTest is ControllerIntegrationTest {
/// forge-config: default.isolate = true
function test_mintRedeem_whenPriceStable() public {
// share:asset ratio is 1:1 when price is stable

Expand Down Expand Up @@ -256,6 +260,7 @@ contract Controller_MintRedeem_IntegrationTest is ControllerIntegrationTest {
assertEq(asset2.balanceOf(address(strategy2)), 50e8);
}

/// forge-config: default.isolate = true
function test_mintRedeem_whenPriceVolatile() public {
priceFeed1.setPrice(1.1e8);
priceFeed2.setPrice(0.8e8);
Expand Down Expand Up @@ -332,6 +337,50 @@ contract Controller_MintRedeem_IntegrationTest is ControllerIntegrationTest {
}
}

contract Controller_RebalanceGuard_IntegrationTest is ControllerIntegrationTest {
/// forge-config: default.isolate = true
function test_preventDepositWithdrawInSameTx() public {
Multicall multicall = new Multicall();

Multicall.Call[] memory calls = new Multicall.Call[](3);
calls[0] = Multicall.Call(
address(asset1), abi.encodeWithSelector(asset1.approve.selector, address(vault1), type(uint256).max)
);
calls[1] = Multicall.Call(address(vault1), abi.encodeWithSelector(vault1.deposit.selector, 100e6, user));
calls[2] = Multicall.Call(address(vault1), abi.encodeWithSelector(vault1.withdraw.selector, 100e6, user, user));

vm.startPrank(user);
require(asset1.transfer(address(multicall), 100e6));
vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
multicall.aggregate(calls);
vm.stopPrank();
}

/// forge-config: default.isolate = true
function test_preventWithdrawDepositInSameTx() public {
Multicall multicall = new Multicall();

Multicall.Call[] memory calls1 = new Multicall.Call[](2);
calls1[0] = Multicall.Call(
address(asset1), abi.encodeWithSelector(asset1.approve.selector, address(vault1), type(uint256).max)
);
calls1[1] = Multicall.Call(address(vault1), abi.encodeWithSelector(vault1.deposit.selector, 100e6, user));

Multicall.Call[] memory calls2 = new Multicall.Call[](2);
calls2[0] = Multicall.Call(address(vault1), abi.encodeWithSelector(vault1.withdraw.selector, 100e6, user, user));
calls2[1] = Multicall.Call(address(vault1), abi.encodeWithSelector(vault1.deposit.selector, 100e6, user));

vm.startPrank(user);
require(asset1.transfer(address(multicall), 100e6));
multicall.aggregate(calls1);
gusd.approve(address(multicall), type(uint256).max);

vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
multicall.aggregate(calls2);
vm.stopPrank();
}
}

contract Controller_YieldDistribution_IntegrationTest is ControllerIntegrationTest {
uint256 errDelta = 0.000001e18;

Expand Down
76 changes: 76 additions & 0 deletions tests/unit/controller/Controller.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,25 @@ contract Controller_Deposit_Test is ControllerTest {
vm.prank(vault);
controller.deposit(1000e18, receiver);
}

function test_shouldSetRebalanceGuardDeposited() public {
_mockVault(vault, asset, 1, feed, 1e8, 8);

vm.prank(vault);
controller.deposit(1000e18, receiver);

assertTrue(controller.exposed_rebalanceGuardDeposited());
}

function test_shouldRevert_whenRebalanceGuardWithdrawn() public {
_mockVault(vault, asset, 1, feed, 1e8, 8);

controller.workaround_setRebalanceGuardWithdrawn(true);

vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
vm.prank(vault);
controller.deposit(1000e18, receiver);
}
}

contract Controller_Mint_Test is ControllerTest {
Expand Down Expand Up @@ -833,6 +852,25 @@ contract Controller_Mint_Test is ControllerTest {
vm.prank(vault);
controller.mint(1000e18, receiver);
}

function test_shouldSetRebalanceGuardDeposited() public {
_mockVault(vault, asset, 1, feed, 1e8, 8);

vm.prank(vault);
controller.mint(1000e18, receiver);

assertTrue(controller.exposed_rebalanceGuardDeposited());
}

function test_shouldRevert_whenRebalanceGuardWithdrawn() public {
_mockVault(vault, asset, 1, feed, 1e8, 8);

controller.workaround_setRebalanceGuardWithdrawn(true);

vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
vm.prank(vault);
controller.mint(1000e18, receiver);
}
}

contract Controller_Withdraw_Test is ControllerTest {
Expand Down Expand Up @@ -977,6 +1015,25 @@ contract Controller_Withdraw_Test is ControllerTest {
vm.prank(vault);
controller.withdraw(100e18, spender, owner);
}

function test_shouldSetRebalanceGuardDeposited() public {
_mockVault(vault, asset, 100e18, feed, 1e8, 8);

vm.prank(vault);
controller.withdraw(100e18, spender, owner);

assertTrue(controller.exposed_rebalanceGuardWithdrawn());
}

function test_shouldRevert_whenRebalanceGuardWithdrawn() public {
_mockVault(vault, asset, 100e18, feed, 1e8, 8);

controller.workaround_setRebalanceGuardDeposited(true);

vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
vm.prank(vault);
controller.withdraw(100e18, spender, owner);
}
}

contract Controller_Redeem_Test is ControllerTest {
Expand Down Expand Up @@ -1128,4 +1185,23 @@ contract Controller_Redeem_Test is ControllerTest {
vm.prank(vault);
controller.redeem(100e18, spender, owner);
}

function test_shouldSetRebalanceGuardDeposited() public {
_mockVault(vault, asset, 100e18, feed, 1e8, 8);

vm.prank(vault);
controller.redeem(100e18, spender, owner);

assertTrue(controller.exposed_rebalanceGuardWithdrawn());
}

function test_shouldRevert_whenRebalanceGuardWithdrawn() public {
_mockVault(vault, asset, 100e18, feed, 1e8, 8);

controller.workaround_setRebalanceGuardDeposited(true);

vm.expectRevert(Controller.Controller_DepositAndWithdrawInSameTx.selector);
vm.prank(vault);
controller.redeem(100e18, spender, owner);
}
}