From 822694a26bea3d490f687492c86e2591a71990c3 Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Sat, 18 May 2024 11:02:33 -0400 Subject: [PATCH 01/16] feat: first test working --- src/PSM.sol | 55 ++++++++++++++++++++++++++++++---- test/Constructor.t.sol | 8 ++--- test/Getters.t.sol | 2 +- test/InflationAttack.t.sol | 56 +++++++++++++++++++++++++++++++++++ test/PSMTestBase.sol | 4 +-- test/harnesses/PSMHarness.sol | 4 +-- 6 files changed, 115 insertions(+), 14 deletions(-) create mode 100644 test/InflationAttack.t.sol diff --git a/src/PSM.sol b/src/PSM.sol index be49033..e43e52f 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: AGPL-3.0-or-later pragma solidity ^0.8.13; +import { console2 } from "forge-std/console2.sol"; + import { IERC20 } from "erc20-helpers/interfaces/IERC20.sol"; import { SafeERC20 } from "erc20-helpers/SafeERC20.sol"; @@ -30,12 +32,18 @@ contract PSM { uint256 public immutable asset0Precision; uint256 public immutable asset1Precision; + uint256 public immutable initialBurnAmount; uint256 public totalShares; mapping(address user => uint256 shares) public shares; - constructor(address asset0_, address asset1_, address rateProvider_) { + constructor( + address asset0_, + address asset1_, + address rateProvider_, + uint256 initialBurnAmount_ + ) { require(asset0_ != address(0), "PSM/invalid-asset0"); require(asset1_ != address(0), "PSM/invalid-asset1"); require(rateProvider_ != address(0), "PSM/invalid-rateProvider"); @@ -44,8 +52,9 @@ contract PSM { asset1 = IERC20(asset1_); rateProvider = rateProvider_; - asset0Precision = 10 ** IERC20(asset0_).decimals(); - asset1Precision = 10 ** IERC20(asset1_).decimals(); + asset0Precision = 10 ** IERC20(asset0_).decimals(); + asset1Precision = 10 ** IERC20(asset1_).decimals(); + initialBurnAmount = initialBurnAmount_; } /**********************************************************************************************/ @@ -86,6 +95,13 @@ contract PSM { // Convert amount to 1e18 precision denominated in value of asset0 then convert to shares. uint256 newShares = convertToShares(_getAssetValue(asset, assetsToDeposit)); + if (totalShares == 0 && initialBurnAmount != 0) { + shares[address(0)] += initialBurnAmount; + totalShares += initialBurnAmount; + + newShares -= initialBurnAmount; + } + shares[msg.sender] += newShares; totalShares += newShares; @@ -103,13 +119,21 @@ contract PSM { ? assetBalance : maxAssetsToWithdraw; - uint256 sharesToBurn = convertToShares(asset, assetsWithdrawn); + console2.log("assetsWithdrawn", assetsWithdrawn); + console2.log("assetBalance ", assetBalance); + + uint256 sharesToBurn = convertToSharesRoundUp(asset, assetsWithdrawn); + + console2.log("sharesToBurn ", sharesToBurn); if (sharesToBurn > shares[msg.sender]) { assetsWithdrawn = convertToAssets(asset, shares[msg.sender]); - sharesToBurn = convertToShares(asset, assetsWithdrawn); + sharesToBurn = convertToSharesRoundUp(asset, assetsWithdrawn); } + console2.log("assetsWithdrawn", assetsWithdrawn); + console2.log("sharesToBurn ", sharesToBurn); + unchecked { shares[msg.sender] -= sharesToBurn; totalShares -= sharesToBurn; @@ -122,19 +146,36 @@ contract PSM { /*** Conversion functions ***/ /**********************************************************************************************/ + // TODO: Refactor to use better naming function convertToShares(uint256 assetValue) public view returns (uint256) { uint256 totalValue = getPsmTotalValue(); + console2.log("totalValue ", totalValue); + console2.log("totalShares", totalShares); + console2.log("assetValue ", assetValue); if (totalValue != 0) { return assetValue * totalShares / totalValue; } return assetValue; } + function convertToSharesRoundUp(uint256 assetValue) public view returns (uint256) { + uint256 totalValue = getPsmTotalValue(); + if (totalValue != 0) { + return _divRoundUp(assetValue * totalShares, totalValue); + } + return assetValue; + } + function convertToShares(address asset, uint256 assets) public view returns (uint256) { require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); return convertToShares(_getAssetValue(asset, assets)); } + function convertToSharesRoundUp(address asset, uint256 assets) public view returns (uint256) { + require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); + return convertToSharesRoundUp(_getAssetValue(asset, assets)); + } + function convertToAssetValue(uint256 numShares) public view returns (uint256) { uint256 totalShares_ = totalShares; @@ -182,6 +223,10 @@ contract PSM { /*** Internal helper functions ***/ /**********************************************************************************************/ + function _divRoundUp(uint256 numerator_, uint256 divisor_) internal pure returns (uint256 result_) { + result_ = (numerator_ + divisor_ - 1) / divisor_; + } + function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) { if (asset == address(asset0)) { return _getAsset0Value(amount); diff --git a/test/Constructor.t.sol b/test/Constructor.t.sol index cdec19d..d505441 100644 --- a/test/Constructor.t.sol +++ b/test/Constructor.t.sol @@ -11,22 +11,22 @@ contract PSMConstructorTests is PSMTestBase { function test_constructor_invalidAsset0() public { vm.expectRevert("PSM/invalid-asset0"); - new PSM(address(0), address(sDai), address(rateProvider)); + new PSM(address(0), address(sDai), address(rateProvider), 1000); } function test_constructor_invalidAsset1() public { vm.expectRevert("PSM/invalid-asset1"); - new PSM(address(usdc), address(0), address(rateProvider)); + new PSM(address(usdc), address(0), address(rateProvider), 1000); } function test_constructor_invalidRateProvider() public { vm.expectRevert("PSM/invalid-rateProvider"); - new PSM(address(sDai), address(usdc), address(0)); + new PSM(address(sDai), address(usdc), address(0), 1000); } function test_constructor() public { // Deploy new PSM to get test coverage - psm = new PSM(address(usdc), address(sDai), address(rateProvider)); + psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); assertEq(address(psm.asset0()), address(usdc)); assertEq(address(psm.asset1()), address(sDai)); diff --git a/test/Getters.t.sol b/test/Getters.t.sol index 911ff91..895f445 100644 --- a/test/Getters.t.sol +++ b/test/Getters.t.sol @@ -13,7 +13,7 @@ contract PSMHarnessTests is PSMTestBase { function setUp() public override { super.setUp(); - psmHarness = new PSMHarness(address(usdc), address(sDai), address(rateProvider)); + psmHarness = new PSMHarness(address(usdc), address(sDai), address(rateProvider), 1000); } function test_getAsset0Value() public view { diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol new file mode 100644 index 0000000..3bdcbe6 --- /dev/null +++ b/test/InflationAttack.t.sol @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later +pragma solidity ^0.8.13; + +import "forge-std/Test.sol"; + +import { PSM } from "src/PSM.sol"; + +import { PSMTestBase } from "test/PSMTestBase.sol"; + +contract InflationAttackTests is PSMTestBase { + + function test_inflationAttack_noInitialBurnAmount() public { + psm = new PSM(address(usdc), address(sDai), address(rateProvider), 0); + + address firstDepositor = makeAddr("firstDepositor"); + address frontRunner = makeAddr("frontRunner"); + + // Step 1: Front runner deposits 1 sDAI to get 1 share + + // Have to use sDai because 1 USDC mints 1e12 shares + _deposit(frontRunner, address(sDai), 1); + + assertEq(psm.shares(frontRunner), 1); + + // Step 2: Front runner transfers 1m USDC to inflate the exchange rate to 1:(1m + 1) + + deal(address(usdc), frontRunner, 1_000_000e6); + + vm.prank(frontRunner); + usdc.transfer(address(psm), 1_000_000e6); + + // Highly inflated exchange rate + assertEq(psm.convertToAssetValue(1), 1_000_000e18 + 1); + + // Step 3: First depositor deposits 2 million USDC, only gets one share because rounding + // error gives them 1 instead of 2 shares, worth 1.5m USDC + + _deposit(firstDepositor, address(usdc), 2_000_000e6); + + assertEq(psm.shares(firstDepositor), 1); + + // 1 share = 3 million USDC / 2 shares = 1.5 million USDC + assertEq(psm.convertToAssetValue(1), 1_500_000e18); + + // Step 4: Both users withdraw the max amount of funds they can + + _withdraw(firstDepositor, address(usdc), type(uint256).max); + _withdraw(frontRunner, address(usdc), type(uint256).max); + + assertEq(usdc.balanceOf(address(psm)), 0); + + // Front runner profits 500k USDC, first depositor loses 500k USDC + assertEq(usdc.balanceOf(firstDepositor), 1_500_000e6); + assertEq(usdc.balanceOf(frontRunner), 1_500_000e6); + } +} diff --git a/test/PSMTestBase.sol b/test/PSMTestBase.sol index 23cc97c..8675d7c 100644 --- a/test/PSMTestBase.sol +++ b/test/PSMTestBase.sol @@ -3,7 +3,7 @@ pragma solidity ^0.8.13; import "forge-std/Test.sol"; -import { PSM } from "../src/PSM.sol"; +import { PSM } from "src/PSM.sol"; import { MockERC20 } from "erc20-helpers/MockERC20.sol"; @@ -38,7 +38,7 @@ contract PSMTestBase is Test { // NOTE: Using 1.25 for easy two way conversions rateProvider.__setConversionRate(1.25e27); - psm = new PSM(address(usdc), address(sDai), address(rateProvider)); + psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); vm.label(address(sDai), "sDAI"); vm.label(address(usdc), "USDC"); diff --git a/test/harnesses/PSMHarness.sol b/test/harnesses/PSMHarness.sol index 6a5eaa7..48dfcf9 100644 --- a/test/harnesses/PSMHarness.sol +++ b/test/harnesses/PSMHarness.sol @@ -5,8 +5,8 @@ import { PSM } from "src/PSM.sol"; contract PSMHarness is PSM { - constructor(address asset0_, address asset1_, address rateProvider_) - PSM(asset0_, asset1_, rateProvider_) {} + constructor(address asset0_, address asset1_, address rateProvider_, uint256 initialBurnAmount_) + PSM(asset0_, asset1_, rateProvider_, initialBurnAmount_) {} function getAssetValue(address asset, uint256 amount) external view returns (uint256) { return _getAssetValue(asset, amount); From d6530e1f275b5d3f61f35dfe7f9f4f7287ff4942 Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Sat, 18 May 2024 11:03:52 -0400 Subject: [PATCH 02/16] feat: use larger numbers: --- test/InflationAttack.t.sol | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index 3bdcbe6..7214651 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -22,25 +22,25 @@ contract InflationAttackTests is PSMTestBase { assertEq(psm.shares(frontRunner), 1); - // Step 2: Front runner transfers 1m USDC to inflate the exchange rate to 1:(1m + 1) + // Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1) - deal(address(usdc), frontRunner, 1_000_000e6); + deal(address(usdc), frontRunner, 10_000_000e6); vm.prank(frontRunner); - usdc.transfer(address(psm), 1_000_000e6); + usdc.transfer(address(psm), 10_000_000e6); // Highly inflated exchange rate - assertEq(psm.convertToAssetValue(1), 1_000_000e18 + 1); + assertEq(psm.convertToAssetValue(1), 10_000_000e18 + 1); - // Step 3: First depositor deposits 2 million USDC, only gets one share because rounding - // error gives them 1 instead of 2 shares, worth 1.5m USDC + // Step 3: First depositor deposits 20 million USDC, only gets one share because rounding + // error gives them 1 instead of 2 shares, worth 15m USDC - _deposit(firstDepositor, address(usdc), 2_000_000e6); + _deposit(firstDepositor, address(usdc), 20_000_000e6); assertEq(psm.shares(firstDepositor), 1); // 1 share = 3 million USDC / 2 shares = 1.5 million USDC - assertEq(psm.convertToAssetValue(1), 1_500_000e18); + assertEq(psm.convertToAssetValue(1), 15_000_000e18); // Step 4: Both users withdraw the max amount of funds they can @@ -49,8 +49,8 @@ contract InflationAttackTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 0); - // Front runner profits 500k USDC, first depositor loses 500k USDC - assertEq(usdc.balanceOf(firstDepositor), 1_500_000e6); - assertEq(usdc.balanceOf(frontRunner), 1_500_000e6); + // Front runner profits 5m USDC, first depositor loses 5m USDC + assertEq(usdc.balanceOf(firstDepositor), 15_000_000e6); + assertEq(usdc.balanceOf(frontRunner), 15_000_000e6); } } From 8f11df43f407314ab3ca67946713e27496a656ee Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Sat, 18 May 2024 11:18:59 -0400 Subject: [PATCH 03/16] feat: test with initial burn amount passing --- test/InflationAttack.t.sol | 64 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index 7214651..a21999c 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -53,4 +53,68 @@ contract InflationAttackTests is PSMTestBase { assertEq(usdc.balanceOf(firstDepositor), 15_000_000e6); assertEq(usdc.balanceOf(frontRunner), 15_000_000e6); } + + function test_inflationAttack_useInitialBurnAmount_firstDepositOverflowBoundary() public { + psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + + address frontRunner = makeAddr("frontRunner"); + + vm.startPrank(frontRunner); + sDai.mint(frontRunner, 800); + sDai.approve(address(psm), 800); + + vm.expectRevert(stdError.arithmeticError); + psm.deposit(address(sDai), 799); + + // 800 sDAI = 1000 shares + psm.deposit(address(sDai), 800); + } + + function test_inflationAttack_useInitialBurnAmount() public { + psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + + address firstDepositor = makeAddr("firstDepositor"); + address frontRunner = makeAddr("frontRunner"); + + // Step 1: Front runner deposits 801 sDAI to get 1 share + + // 1000 shares get burned, user is left with 1 + _deposit(frontRunner, address(sDai), 801); + + assertEq(psm.shares(frontRunner), 1); + + // Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1) + + deal(address(usdc), frontRunner, 10_000_000e6); + + vm.prank(frontRunner); + usdc.transfer(address(psm), 10_000_000e6); + + // Much less inflated exchange rate + assertEq(psm.convertToAssetValue(1), 9990.009990009990009991e18); + + // Step 3: First depositor deposits 20 million USDC, only gets one share because rounding + // error gives them 1 instead of 2 shares, worth 15m USDC + + _deposit(firstDepositor, address(usdc), 20_000_000e6); + + assertEq(psm.shares(firstDepositor), 2001); + + // Higher amount of initial shares means lower rounding error + assertEq(psm.convertToAssetValue(2001), 19_996_668.887408394403731513e18); + + // Step 4: Both users withdraw the max amount of funds they can + + _withdraw(firstDepositor, address(usdc), type(uint256).max); + _withdraw(frontRunner, address(usdc), type(uint256).max); + + // Burnt shares have a claim on these + // TODO: Should this be an admin contract instead of address(0)? + assertEq(usdc.balanceOf(address(psm)), 9_993_337.774818e6); + + // Front runner loses 999k USDC, first depositor loses 4k USDC + assertEq(usdc.balanceOf(firstDepositor), 19_996_668.887408e6); + assertEq(usdc.balanceOf(frontRunner), 9_993.337774e6); + } + } From 78fb4ad59a8b2d8a14a369f105f9cddf0d5b7d14 Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Tue, 21 May 2024 10:47:48 -0400 Subject: [PATCH 04/16] feat: update tests to work with updated burn logic, move conversion functions around and use previews --- src/PSM.sol | 145 +++++++++++++++++++------------------ test/Deposit.t.sol | 40 ++++++---- test/InflationAttack.t.sol | 2 +- test/Withdraw.t.sol | 79 +++++++++++++------- 4 files changed, 152 insertions(+), 114 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index e43e52f..604dfa1 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -89,11 +89,10 @@ contract PSM { /*** Liquidity provision functions ***/ /**********************************************************************************************/ - function deposit(address asset, uint256 assetsToDeposit) external { - require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); - - // Convert amount to 1e18 precision denominated in value of asset0 then convert to shares. - uint256 newShares = convertToShares(_getAssetValue(asset, assetsToDeposit)); + function deposit(address asset, uint256 assetsToDeposit) + external returns (uint256 newShares) + { + newShares = previewDeposit(asset, assetsToDeposit); if (totalShares == 0 && initialBurnAmount != 0) { shares[address(0)] += initialBurnAmount; @@ -110,6 +109,32 @@ contract PSM { function withdraw(address asset, uint256 maxAssetsToWithdraw) external returns (uint256 assetsWithdrawn) + { + uint256 sharesToBurn; + + ( sharesToBurn, assetsWithdrawn ) = previewWithdraw(asset, maxAssetsToWithdraw); + + unchecked { + shares[msg.sender] -= sharesToBurn; + totalShares -= sharesToBurn; + } + + IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn); + } + + /**********************************************************************************************/ + /*** Deposit/withdraw preview functions ***/ + /**********************************************************************************************/ + + function previewDeposit(address asset, uint256 assets) public view returns (uint256) { + require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); + + // Convert amount to 1e18 precision denominated in value of asset0 then convert to shares. + return convertToShares(_getAssetValue(asset, assets)); + } + + function previewWithdraw(address asset, uint256 maxAssetsToWithdraw) + public view returns (uint256 sharesToBurn, uint256 assetsWithdrawn) { require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); @@ -119,61 +144,41 @@ contract PSM { ? assetBalance : maxAssetsToWithdraw; - console2.log("assetsWithdrawn", assetsWithdrawn); - console2.log("assetBalance ", assetBalance); - - uint256 sharesToBurn = convertToSharesRoundUp(asset, assetsWithdrawn); - - console2.log("sharesToBurn ", sharesToBurn); + sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); if (sharesToBurn > shares[msg.sender]) { assetsWithdrawn = convertToAssets(asset, shares[msg.sender]); - sharesToBurn = convertToSharesRoundUp(asset, assetsWithdrawn); + sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); // TODO: This can cause an underflow, refactor to use full shares balance? } - - console2.log("assetsWithdrawn", assetsWithdrawn); - console2.log("sharesToBurn ", sharesToBurn); - - unchecked { - shares[msg.sender] -= sharesToBurn; - totalShares -= sharesToBurn; - } - - IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn); } /**********************************************************************************************/ - /*** Conversion functions ***/ + /*** Swap preview functions ***/ /**********************************************************************************************/ - // TODO: Refactor to use better naming - function convertToShares(uint256 assetValue) public view returns (uint256) { - uint256 totalValue = getPsmTotalValue(); - console2.log("totalValue ", totalValue); - console2.log("totalShares", totalShares); - console2.log("assetValue ", assetValue); - if (totalValue != 0) { - return assetValue * totalShares / totalValue; - } - return assetValue; + function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) { + return amountIn + * 1e27 + * asset1Precision + / IRateProviderLike(rateProvider).getConversionRate() + / asset0Precision; } - function convertToSharesRoundUp(uint256 assetValue) public view returns (uint256) { - uint256 totalValue = getPsmTotalValue(); - if (totalValue != 0) { - return _divRoundUp(assetValue * totalShares, totalValue); - } - return assetValue; + function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) { + return amountIn + * IRateProviderLike(rateProvider).getConversionRate() + * asset0Precision + / 1e27 + / asset1Precision; } - function convertToShares(address asset, uint256 assets) public view returns (uint256) { - require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); - return convertToShares(_getAssetValue(asset, assets)); - } + /**********************************************************************************************/ + /*** Conversion functions ***/ + /**********************************************************************************************/ - function convertToSharesRoundUp(address asset, uint256 assets) public view returns (uint256) { + function convertToAssets(address asset, uint256 numShares) public view returns (uint256) { require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); - return convertToSharesRoundUp(_getAssetValue(asset, assets)); + return _getAssetsByValue(asset, convertToAssetValue(numShares)); } function convertToAssetValue(uint256 numShares) public view returns (uint256) { @@ -185,9 +190,17 @@ contract PSM { return numShares; } - function convertToAssets(address asset, uint256 numShares) public view returns (uint256) { + function convertToShares(uint256 assetValue) public view returns (uint256) { + uint256 totalValue = getPsmTotalValue(); + if (totalValue != 0) { + return assetValue * totalShares / totalValue; + } + return assetValue; + } + + function convertToShares(address asset, uint256 assets) public view returns (uint256) { require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); - return _getAssetsByValue(asset, convertToAssetValue(numShares)); + return convertToShares(_getAssetValue(asset, assets)); } /**********************************************************************************************/ @@ -200,39 +213,27 @@ contract PSM { } /**********************************************************************************************/ - /*** Swap preview functions ***/ + /*** Internal helper functions ***/ /**********************************************************************************************/ - function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) { - return amountIn - * 1e27 - * asset1Precision - / IRateProviderLike(rateProvider).getConversionRate() - / asset0Precision; - } - - function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) { - return amountIn - * IRateProviderLike(rateProvider).getConversionRate() - * asset0Precision - / 1e27 - / asset1Precision; + function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) { + uint256 totalValue = getPsmTotalValue(); + if (totalValue != 0) { + return _divRoundUp(assetValue * totalShares, totalValue); + } + return assetValue; } - /**********************************************************************************************/ - /*** Internal helper functions ***/ - /**********************************************************************************************/ - - function _divRoundUp(uint256 numerator_, uint256 divisor_) internal pure returns (uint256 result_) { + function _divRoundUp(uint256 numerator_, uint256 divisor_) + internal pure returns (uint256 result_) + { result_ = (numerator_ + divisor_ - 1) / divisor_; } function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) { - if (asset == address(asset0)) { - return _getAsset0Value(amount); - } - - return _getAsset1Value(amount); + return asset == address(asset0) + ? _getAsset0Value(amount) + : _getAsset1Value(amount); } function _getAssetsByValue(address asset, uint256 assetValue) internal view returns (uint256) { diff --git a/test/Deposit.t.sol b/test/Deposit.t.sol index 24ce76a..f9924f8 100644 --- a/test/Deposit.t.sol +++ b/test/Deposit.t.sol @@ -11,6 +11,9 @@ contract PSMDepositTests is PSMTestBase { address user1 = makeAddr("user1"); address user2 = makeAddr("user2"); + address burn = address(0); + + uint256 BURN_AMOUNT = 1000; function test_deposit_notAsset0OrAsset1() public { vm.expectRevert("PSM/invalid-asset"); @@ -32,6 +35,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); + assertEq(psm.shares(burn), 0); assertEq(psm.convertToShares(1e18), 1e18); @@ -42,7 +46,8 @@ contract PSMDepositTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18); + assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -60,6 +65,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); + assertEq(psm.shares(burn), 0); assertEq(psm.convertToShares(1e18), 1e18); @@ -70,7 +76,8 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 125e18); - assertEq(psm.shares(user1), 125e18); + assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -94,7 +101,8 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18); + assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); @@ -107,14 +115,16 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18); + assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); // Only burn on first deposit assertEq(psm.convertToShares(1e18), 1e18); } function testFuzz_deposit_usdcThenSDai(uint256 usdcAmount, uint256 sDaiAmount) public { - usdcAmount = _bound(usdcAmount, 0, USDC_TOKEN_MAX); - sDaiAmount = _bound(sDaiAmount, 0, SDAI_TOKEN_MAX); + // NOTE: Deposits revert if deposit amount is less than the burn amount + usdcAmount = _bound(usdcAmount, BURN_AMOUNT, USDC_TOKEN_MAX); + sDaiAmount = _bound(sDaiAmount, BURN_AMOUNT, SDAI_TOKEN_MAX); usdc.mint(user1, usdcAmount); @@ -134,7 +144,8 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), usdcAmount * 1e12); - assertEq(psm.shares(user1), usdcAmount * 1e12); + assertEq(psm.shares(user1), usdcAmount * 1e12 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); @@ -147,7 +158,8 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), sDaiAmount); assertEq(psm.totalShares(), usdcAmount * 1e12 + sDaiAmount * 125/100); - assertEq(psm.shares(user1), usdcAmount * 1e12 + sDaiAmount * 125/100); + assertEq(psm.shares(user1), usdcAmount * 1e12 + sDaiAmount * 125/100 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); // Only burn on first deposit assertEq(psm.convertToShares(1e18), 1e18); } @@ -175,11 +187,12 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18); + assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); - assertEq(psm.convertToAssetValue(psm.shares(user1)), 225e18); + assertEq(psm.convertToAssetValue(psm.shares(user1)), 225e18 - BURN_AMOUNT); rateProvider.__setConversionRate(1.5e27); @@ -199,7 +212,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(user2), 100e18); assertEq(sDai.balanceOf(address(psm)), 100e18); - assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18); + assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18 - 1112); // Burn amount conversion assertEq(psm.convertToAssetValue(psm.shares(user2)), 0); assertEq(psm.getPsmTotalValue(), 250e18); @@ -218,11 +231,12 @@ contract PSMDepositTests is PSMTestBase { assertEq(expectedShares, 135e18); assertEq(psm.totalShares(), 360e18); - assertEq(psm.shares(user1), 225e18); + assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); assertEq(psm.shares(user2), 135e18); + assertEq(psm.shares(burn), BURN_AMOUNT); // User 1 earned $25 on 225, User 2 has earned nothing - assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18); + assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18 - 1112); assertEq(psm.convertToAssetValue(psm.shares(user2)), 150e18); assertEq(psm.getPsmTotalValue(), 400e18); diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index a21999c..9f0e189 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -112,7 +112,7 @@ contract InflationAttackTests is PSMTestBase { // TODO: Should this be an admin contract instead of address(0)? assertEq(usdc.balanceOf(address(psm)), 9_993_337.774818e6); - // Front runner loses 999k USDC, first depositor loses 4k USDC + // Front runner loses 9.99m USDC, first depositor loses 4k USDC assertEq(usdc.balanceOf(firstDepositor), 19_996_668.887408e6); assertEq(usdc.balanceOf(frontRunner), 9_993.337774e6); } diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index 50eeb78..421db3f 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -13,6 +13,9 @@ contract PSMWithdrawTests is PSMTestBase { address user1 = makeAddr("user1"); address user2 = makeAddr("user2"); + address burn = address(0); + + uint256 BURN_AMOUNT = 1000; function test_withdraw_notAsset0OrAsset1() public { vm.expectRevert("PSM/invalid-asset"); @@ -28,20 +31,24 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18); + assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); vm.prank(user1); uint256 amount = psm.withdraw(address(usdc), 100e6); - assertEq(amount, 100e6); + // Burn amount causes shares to round down by one since shares are 99.999... + assertEq(amount, 100e6 - 1); - assertEq(usdc.balanceOf(user1), 100e6); - assertEq(usdc.balanceOf(address(psm)), 0); + assertEq(usdc.balanceOf(user1), 100e6 - 1); + assertEq(usdc.balanceOf(address(psm)), 1); - assertEq(psm.totalShares(), 0); - assertEq(psm.shares(user1), 0); + // User still has left over shares from rounding on 1e6 + assertEq(psm.totalShares(), 1e12); + assertEq(psm.shares(user1), 1e12 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -53,20 +60,25 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 80e18); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18); + assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); + + // This is the amount that this user will not be able to withdraw + assertEq(psm.convertToAssets(address(sDai), BURN_AMOUNT), 800); assertEq(psm.convertToShares(1e18), 1e18); vm.prank(user1); uint256 amount = psm.withdraw(address(sDai), 80e18); - assertEq(amount, 80e18); + assertEq(amount, 80e18 - 800); - assertEq(sDai.balanceOf(user1), 80e18); - assertEq(sDai.balanceOf(address(psm)), 0); + assertEq(sDai.balanceOf(user1), 80e18 - 800); + assertEq(sDai.balanceOf(address(psm)), 800); - assertEq(psm.totalShares(), 0); + assertEq(psm.totalShares(), BURN_AMOUNT); assertEq(psm.shares(user1), 0); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -82,7 +94,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18); + assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); @@ -98,23 +111,28 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 125e18); - assertEq(psm.shares(user1), 125e18); + assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); + // This is the amount that this user will not be able to withdraw + assertEq(psm.convertToAssets(address(sDai), BURN_AMOUNT), 800); + vm.prank(user1); amount = psm.withdraw(address(sDai), 100e18); - assertEq(amount, 100e18); + assertEq(amount, 100e18 - 800); assertEq(usdc.balanceOf(user1), 100e6); assertEq(usdc.balanceOf(address(psm)), 0); - assertEq(sDai.balanceOf(user1), 100e18); - assertEq(sDai.balanceOf(address(psm)), 0); + assertEq(sDai.balanceOf(user1), 100e18 - 800); + assertEq(sDai.balanceOf(address(psm)), 800); - assertEq(psm.totalShares(), 0); + assertEq(psm.totalShares(), BURN_AMOUNT); assertEq(psm.shares(user1), 0); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -127,7 +145,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18); + assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); @@ -140,7 +159,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), 125e18); // Only burns $100 of shares - assertEq(psm.shares(user1), 125e18); + assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); } function test_withdraw_amountHigherThanUserShares() public { @@ -178,11 +198,10 @@ contract PSMWithdrawTests is PSMTestBase { ) public { - // NOTE: Not covering zero cases, 1e-2 at 1e6 used as min for now so exact values can - // be asserted - depositAmount1 = bound(depositAmount1, 0, USDC_TOKEN_MAX); - depositAmount2 = bound(depositAmount2, 0, USDC_TOKEN_MAX); - depositAmount3 = bound(depositAmount3, 0, SDAI_TOKEN_MAX); + // NOTE: Deposits revert if deposit amount is less than the burn amount + depositAmount1 = bound(depositAmount1, BURN_AMOUNT, USDC_TOKEN_MAX); + depositAmount2 = bound(depositAmount2, BURN_AMOUNT, USDC_TOKEN_MAX); + depositAmount3 = bound(depositAmount3, BURN_AMOUNT, SDAI_TOKEN_MAX); withdrawAmount1 = bound(withdrawAmount1, 0, USDC_TOKEN_MAX); withdrawAmount2 = bound(withdrawAmount2, 0, USDC_TOKEN_MAX); @@ -198,7 +217,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(user1), 0); assertEq(usdc.balanceOf(address(psm)), totalUsdc); - assertEq(psm.shares(user1), depositAmount1 * 1e12); + assertEq(psm.shares(user1), depositAmount1 * 1e12 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.totalShares(), totalValue); uint256 expectedWithdrawnAmount1 @@ -220,8 +240,9 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(user2), 0); assertEq(usdc.balanceOf(address(psm)), totalUsdc - expectedWithdrawnAmount1); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); assertEq(psm.shares(user2), depositAmount2 * 1e12 + depositAmount3 * 125/100); // Includes sDAI deposit + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.totalShares(), totalValue - expectedWithdrawnAmount1 * 1e12); uint256 expectedWithdrawnAmount2 @@ -246,7 +267,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(user2), 0); assertEq(sDai.balanceOf(address(psm)), depositAmount3); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertEq( psm.shares(user2), @@ -280,7 +302,8 @@ contract PSMWithdrawTests is PSMTestBase { assertApproxEqAbs(sDai.balanceOf(user2), expectedWithdrawnAmount3, 1); assertApproxEqAbs(sDai.balanceOf(address(psm)), depositAmount3 - expectedWithdrawnAmount3, 1); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); + assertEq(psm.shares(burn), BURN_AMOUNT); assertApproxEqAbs( psm.shares(user2), From d32fd2faf5d3d255075f4eef9d6639f1118f5eb3 Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Tue, 21 May 2024 11:10:08 -0400 Subject: [PATCH 05/16] feat: remove todos --- src/PSM.sol | 7 +++---- test/Withdraw.t.sol | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index 604dfa1..ff39629 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -16,9 +16,6 @@ interface IRateProviderLike { // TODO: Refactor into inheritance structure // TODO: Add interface with natspec and inherit // TODO: Prove that we're always rounding against user -// TODO: Frontrunning attack, donation attack, virtual balances? -// TODO: Figure out how to optimize require checks for assets in view functions -// TODO: Discuss if we should add ERC20 functionality contract PSM { using SafeERC20 for IERC20; @@ -146,9 +143,11 @@ contract PSM { sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); + // TODO: Refactor this section to not use convertToAssets because of redundant check + // TODO: This can cause an underflow in shares, refactor to use full shares balance? if (sharesToBurn > shares[msg.sender]) { assetsWithdrawn = convertToAssets(asset, shares[msg.sender]); - sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); // TODO: This can cause an underflow, refactor to use full shares balance? + sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); } } diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index 421db3f..fb84371 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -357,7 +357,7 @@ contract PSMWithdrawTests is PSMTestBase { userAssets = userAssets * 1e27 / rateProvider.getConversionRate(); } - // Return the min of + // Return the min of assets, balance, and amount withdrawAmount = userAssets < balance ? userAssets : balance; withdrawAmount = amount < withdrawAmount ? amount : withdrawAmount; } From 2290417607390681c553a68cf2194f31951b567f Mon Sep 17 00:00:00 2001 From: lucas-manuel Date: Tue, 21 May 2024 13:07:15 -0400 Subject: [PATCH 06/16] fix: update to remove console and update comment --- src/PSM.sol | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index ff39629..e5b363c 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -1,8 +1,6 @@ // SPDX-License-Identifier: AGPL-3.0-or-later pragma solidity ^0.8.13; -import { console2 } from "forge-std/console2.sol"; - import { IERC20 } from "erc20-helpers/interfaces/IERC20.sol"; import { SafeERC20 } from "erc20-helpers/SafeERC20.sol"; @@ -144,7 +142,7 @@ contract PSM { sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); // TODO: Refactor this section to not use convertToAssets because of redundant check - // TODO: This can cause an underflow in shares, refactor to use full shares balance? + // TODO: Can this cause an underflow in shares? Refactor to use full shares balance? if (sharesToBurn > shares[msg.sender]) { assetsWithdrawn = convertToAssets(asset, shares[msg.sender]); sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); From 54a4afefde8a33a64c23612e0f70d6760464f4a6 Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 30 May 2024 07:33:54 -0400 Subject: [PATCH 07/16] feat: remove all share burn logic, get all non inflation attack tests to pass --- src/PSM.sol | 12 +----- test/Constructor.t.sol | 8 ++-- test/Deposit.t.sol | 40 ++++++----------- test/Getters.t.sol | 2 +- test/InflationAttack.t.sol | 6 +-- test/PSMTestBase.sol | 2 +- test/Withdraw.t.sol | 81 +++++++++++++---------------------- test/harnesses/PSMHarness.sol | 4 +- 8 files changed, 54 insertions(+), 101 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index e5b363c..f411e4e 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -27,7 +27,6 @@ contract PSM { uint256 public immutable asset0Precision; uint256 public immutable asset1Precision; - uint256 public immutable initialBurnAmount; uint256 public totalShares; @@ -36,8 +35,7 @@ contract PSM { constructor( address asset0_, address asset1_, - address rateProvider_, - uint256 initialBurnAmount_ + address rateProvider_ ) { require(asset0_ != address(0), "PSM/invalid-asset0"); require(asset1_ != address(0), "PSM/invalid-asset1"); @@ -49,7 +47,6 @@ contract PSM { asset0Precision = 10 ** IERC20(asset0_).decimals(); asset1Precision = 10 ** IERC20(asset1_).decimals(); - initialBurnAmount = initialBurnAmount_; } /**********************************************************************************************/ @@ -89,13 +86,6 @@ contract PSM { { newShares = previewDeposit(asset, assetsToDeposit); - if (totalShares == 0 && initialBurnAmount != 0) { - shares[address(0)] += initialBurnAmount; - totalShares += initialBurnAmount; - - newShares -= initialBurnAmount; - } - shares[msg.sender] += newShares; totalShares += newShares; diff --git a/test/Constructor.t.sol b/test/Constructor.t.sol index d505441..cdec19d 100644 --- a/test/Constructor.t.sol +++ b/test/Constructor.t.sol @@ -11,22 +11,22 @@ contract PSMConstructorTests is PSMTestBase { function test_constructor_invalidAsset0() public { vm.expectRevert("PSM/invalid-asset0"); - new PSM(address(0), address(sDai), address(rateProvider), 1000); + new PSM(address(0), address(sDai), address(rateProvider)); } function test_constructor_invalidAsset1() public { vm.expectRevert("PSM/invalid-asset1"); - new PSM(address(usdc), address(0), address(rateProvider), 1000); + new PSM(address(usdc), address(0), address(rateProvider)); } function test_constructor_invalidRateProvider() public { vm.expectRevert("PSM/invalid-rateProvider"); - new PSM(address(sDai), address(usdc), address(0), 1000); + new PSM(address(sDai), address(usdc), address(0)); } function test_constructor() public { // Deploy new PSM to get test coverage - psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + psm = new PSM(address(usdc), address(sDai), address(rateProvider)); assertEq(address(psm.asset0()), address(usdc)); assertEq(address(psm.asset1()), address(sDai)); diff --git a/test/Deposit.t.sol b/test/Deposit.t.sol index f9924f8..24ce76a 100644 --- a/test/Deposit.t.sol +++ b/test/Deposit.t.sol @@ -11,9 +11,6 @@ contract PSMDepositTests is PSMTestBase { address user1 = makeAddr("user1"); address user2 = makeAddr("user2"); - address burn = address(0); - - uint256 BURN_AMOUNT = 1000; function test_deposit_notAsset0OrAsset1() public { vm.expectRevert("PSM/invalid-asset"); @@ -35,7 +32,6 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); - assertEq(psm.shares(burn), 0); assertEq(psm.convertToShares(1e18), 1e18); @@ -46,8 +42,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 100e18); assertEq(psm.convertToShares(1e18), 1e18); } @@ -65,7 +60,6 @@ contract PSMDepositTests is PSMTestBase { assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); - assertEq(psm.shares(burn), 0); assertEq(psm.convertToShares(1e18), 1e18); @@ -76,8 +70,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 125e18); - assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 125e18); assertEq(psm.convertToShares(1e18), 1e18); } @@ -101,8 +94,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 100e18); assertEq(psm.convertToShares(1e18), 1e18); @@ -115,16 +107,14 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); // Only burn on first deposit + assertEq(psm.shares(user1), 225e18); assertEq(psm.convertToShares(1e18), 1e18); } function testFuzz_deposit_usdcThenSDai(uint256 usdcAmount, uint256 sDaiAmount) public { - // NOTE: Deposits revert if deposit amount is less than the burn amount - usdcAmount = _bound(usdcAmount, BURN_AMOUNT, USDC_TOKEN_MAX); - sDaiAmount = _bound(sDaiAmount, BURN_AMOUNT, SDAI_TOKEN_MAX); + usdcAmount = _bound(usdcAmount, 0, USDC_TOKEN_MAX); + sDaiAmount = _bound(sDaiAmount, 0, SDAI_TOKEN_MAX); usdc.mint(user1, usdcAmount); @@ -144,8 +134,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), usdcAmount * 1e12); - assertEq(psm.shares(user1), usdcAmount * 1e12 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), usdcAmount * 1e12); assertEq(psm.convertToShares(1e18), 1e18); @@ -158,8 +147,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), sDaiAmount); assertEq(psm.totalShares(), usdcAmount * 1e12 + sDaiAmount * 125/100); - assertEq(psm.shares(user1), usdcAmount * 1e12 + sDaiAmount * 125/100 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); // Only burn on first deposit + assertEq(psm.shares(user1), usdcAmount * 1e12 + sDaiAmount * 125/100); assertEq(psm.convertToShares(1e18), 1e18); } @@ -187,12 +175,11 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 225e18); assertEq(psm.convertToShares(1e18), 1e18); - assertEq(psm.convertToAssetValue(psm.shares(user1)), 225e18 - BURN_AMOUNT); + assertEq(psm.convertToAssetValue(psm.shares(user1)), 225e18); rateProvider.__setConversionRate(1.5e27); @@ -212,7 +199,7 @@ contract PSMDepositTests is PSMTestBase { assertEq(sDai.balanceOf(user2), 100e18); assertEq(sDai.balanceOf(address(psm)), 100e18); - assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18 - 1112); // Burn amount conversion + assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18); assertEq(psm.convertToAssetValue(psm.shares(user2)), 0); assertEq(psm.getPsmTotalValue(), 250e18); @@ -231,12 +218,11 @@ contract PSMDepositTests is PSMTestBase { assertEq(expectedShares, 135e18); assertEq(psm.totalShares(), 360e18); - assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); + assertEq(psm.shares(user1), 225e18); assertEq(psm.shares(user2), 135e18); - assertEq(psm.shares(burn), BURN_AMOUNT); // User 1 earned $25 on 225, User 2 has earned nothing - assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18 - 1112); + assertEq(psm.convertToAssetValue(psm.shares(user1)), 250e18); assertEq(psm.convertToAssetValue(psm.shares(user2)), 150e18); assertEq(psm.getPsmTotalValue(), 400e18); diff --git a/test/Getters.t.sol b/test/Getters.t.sol index 895f445..911ff91 100644 --- a/test/Getters.t.sol +++ b/test/Getters.t.sol @@ -13,7 +13,7 @@ contract PSMHarnessTests is PSMTestBase { function setUp() public override { super.setUp(); - psmHarness = new PSMHarness(address(usdc), address(sDai), address(rateProvider), 1000); + psmHarness = new PSMHarness(address(usdc), address(sDai), address(rateProvider)); } function test_getAsset0Value() public view { diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index 9f0e189..f0b4c4d 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -10,7 +10,7 @@ import { PSMTestBase } from "test/PSMTestBase.sol"; contract InflationAttackTests is PSMTestBase { function test_inflationAttack_noInitialBurnAmount() public { - psm = new PSM(address(usdc), address(sDai), address(rateProvider), 0); + psm = new PSM(address(usdc), address(sDai), address(rateProvider)); address firstDepositor = makeAddr("firstDepositor"); address frontRunner = makeAddr("frontRunner"); @@ -55,7 +55,7 @@ contract InflationAttackTests is PSMTestBase { } function test_inflationAttack_useInitialBurnAmount_firstDepositOverflowBoundary() public { - psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + psm = new PSM(address(usdc), address(sDai), address(rateProvider)); address frontRunner = makeAddr("frontRunner"); @@ -71,7 +71,7 @@ contract InflationAttackTests is PSMTestBase { } function test_inflationAttack_useInitialBurnAmount() public { - psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + psm = new PSM(address(usdc), address(sDai), address(rateProvider)); address firstDepositor = makeAddr("firstDepositor"); address frontRunner = makeAddr("frontRunner"); diff --git a/test/PSMTestBase.sol b/test/PSMTestBase.sol index 8675d7c..625b685 100644 --- a/test/PSMTestBase.sol +++ b/test/PSMTestBase.sol @@ -38,7 +38,7 @@ contract PSMTestBase is Test { // NOTE: Using 1.25 for easy two way conversions rateProvider.__setConversionRate(1.25e27); - psm = new PSM(address(usdc), address(sDai), address(rateProvider), 1000); + psm = new PSM(address(usdc), address(sDai), address(rateProvider)); vm.label(address(sDai), "sDAI"); vm.label(address(usdc), "USDC"); diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index fb84371..bcabccb 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -13,9 +13,6 @@ contract PSMWithdrawTests is PSMTestBase { address user1 = makeAddr("user1"); address user2 = makeAddr("user2"); - address burn = address(0); - - uint256 BURN_AMOUNT = 1000; function test_withdraw_notAsset0OrAsset1() public { vm.expectRevert("PSM/invalid-asset"); @@ -31,24 +28,20 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 100e18); assertEq(psm.convertToShares(1e18), 1e18); vm.prank(user1); uint256 amount = psm.withdraw(address(usdc), 100e6); - // Burn amount causes shares to round down by one since shares are 99.999... - assertEq(amount, 100e6 - 1); + assertEq(amount, 100e6); - assertEq(usdc.balanceOf(user1), 100e6 - 1); - assertEq(usdc.balanceOf(address(psm)), 1); + assertEq(usdc.balanceOf(user1), 100e6); + assertEq(usdc.balanceOf(address(psm)), 0); - // User still has left over shares from rounding on 1e6 - assertEq(psm.totalShares(), 1e12); - assertEq(psm.shares(user1), 1e12 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.totalShares(), 0); + assertEq(psm.shares(user1), 0); assertEq(psm.convertToShares(1e18), 1e18); } @@ -60,25 +53,20 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 80e18); assertEq(psm.totalShares(), 100e18); - assertEq(psm.shares(user1), 100e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); - - // This is the amount that this user will not be able to withdraw - assertEq(psm.convertToAssets(address(sDai), BURN_AMOUNT), 800); + assertEq(psm.shares(user1), 100e18); assertEq(psm.convertToShares(1e18), 1e18); vm.prank(user1); uint256 amount = psm.withdraw(address(sDai), 80e18); - assertEq(amount, 80e18 - 800); + assertEq(amount, 80e18); - assertEq(sDai.balanceOf(user1), 80e18 - 800); - assertEq(sDai.balanceOf(address(psm)), 800); + assertEq(sDai.balanceOf(user1), 80e18); + assertEq(sDai.balanceOf(address(psm)), 0); - assertEq(psm.totalShares(), BURN_AMOUNT); + assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); - assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -94,8 +82,7 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 225e18); assertEq(psm.convertToShares(1e18), 1e18); @@ -111,28 +98,23 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(address(psm)), 100e18); assertEq(psm.totalShares(), 125e18); - assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 125e18); assertEq(psm.convertToShares(1e18), 1e18); - // This is the amount that this user will not be able to withdraw - assertEq(psm.convertToAssets(address(sDai), BURN_AMOUNT), 800); - vm.prank(user1); amount = psm.withdraw(address(sDai), 100e18); - assertEq(amount, 100e18 - 800); + assertEq(amount, 100e18); assertEq(usdc.balanceOf(user1), 100e6); assertEq(usdc.balanceOf(address(psm)), 0); - assertEq(sDai.balanceOf(user1), 100e18 - 800); - assertEq(sDai.balanceOf(address(psm)), 800); + assertEq(sDai.balanceOf(user1), 100e18); + assertEq(sDai.balanceOf(address(psm)), 0); - assertEq(psm.totalShares(), BURN_AMOUNT); + assertEq(psm.totalShares(), 0); assertEq(psm.shares(user1), 0); - assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.convertToShares(1e18), 1e18); } @@ -145,8 +127,7 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 100e6); assertEq(psm.totalShares(), 225e18); - assertEq(psm.shares(user1), 225e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 225e18); assertEq(psm.convertToShares(1e18), 1e18); @@ -159,8 +140,7 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(address(psm)), 0); assertEq(psm.totalShares(), 125e18); // Only burns $100 of shares - assertEq(psm.shares(user1), 125e18 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), 125e18); } function test_withdraw_amountHigherThanUserShares() public { @@ -198,10 +178,11 @@ contract PSMWithdrawTests is PSMTestBase { ) public { - // NOTE: Deposits revert if deposit amount is less than the burn amount - depositAmount1 = bound(depositAmount1, BURN_AMOUNT, USDC_TOKEN_MAX); - depositAmount2 = bound(depositAmount2, BURN_AMOUNT, USDC_TOKEN_MAX); - depositAmount3 = bound(depositAmount3, BURN_AMOUNT, SDAI_TOKEN_MAX); + // NOTE: Not covering zero cases, 1e-2 at 1e6 used as min for now so exact values can + // be asserted + depositAmount1 = bound(depositAmount1, 0, USDC_TOKEN_MAX); + depositAmount2 = bound(depositAmount2, 0, USDC_TOKEN_MAX); + depositAmount3 = bound(depositAmount3, 0, SDAI_TOKEN_MAX); withdrawAmount1 = bound(withdrawAmount1, 0, USDC_TOKEN_MAX); withdrawAmount2 = bound(withdrawAmount2, 0, USDC_TOKEN_MAX); @@ -217,8 +198,7 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(user1), 0); assertEq(usdc.balanceOf(address(psm)), totalUsdc); - assertEq(psm.shares(user1), depositAmount1 * 1e12 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), depositAmount1 * 1e12); assertEq(psm.totalShares(), totalValue); uint256 expectedWithdrawnAmount1 @@ -240,9 +220,8 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(usdc.balanceOf(user2), 0); assertEq(usdc.balanceOf(address(psm)), totalUsdc - expectedWithdrawnAmount1); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); assertEq(psm.shares(user2), depositAmount2 * 1e12 + depositAmount3 * 125/100); // Includes sDAI deposit - assertEq(psm.shares(burn), BURN_AMOUNT); assertEq(psm.totalShares(), totalValue - expectedWithdrawnAmount1 * 1e12); uint256 expectedWithdrawnAmount2 @@ -267,8 +246,7 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(sDai.balanceOf(user2), 0); assertEq(sDai.balanceOf(address(psm)), depositAmount3); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); assertEq( psm.shares(user2), @@ -302,8 +280,7 @@ contract PSMWithdrawTests is PSMTestBase { assertApproxEqAbs(sDai.balanceOf(user2), expectedWithdrawnAmount3, 1); assertApproxEqAbs(sDai.balanceOf(address(psm)), depositAmount3 - expectedWithdrawnAmount3, 1); - assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12 - BURN_AMOUNT); - assertEq(psm.shares(burn), BURN_AMOUNT); + assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); assertApproxEqAbs( psm.shares(user2), @@ -333,7 +310,7 @@ contract PSMWithdrawTests is PSMTestBase { // ); } - function _checkPsmInvariant() internal view { + function _checkPsmInvariant() internal { uint256 totalSharesValue = psm.convertToAssetValue(psm.totalShares()); uint256 totalAssetsValue = sDai.balanceOf(address(psm)) * rateProvider.getConversionRate() / 1e27 diff --git a/test/harnesses/PSMHarness.sol b/test/harnesses/PSMHarness.sol index 48dfcf9..6a5eaa7 100644 --- a/test/harnesses/PSMHarness.sol +++ b/test/harnesses/PSMHarness.sol @@ -5,8 +5,8 @@ import { PSM } from "src/PSM.sol"; contract PSMHarness is PSM { - constructor(address asset0_, address asset1_, address rateProvider_, uint256 initialBurnAmount_) - PSM(asset0_, asset1_, rateProvider_, initialBurnAmount_) {} + constructor(address asset0_, address asset1_, address rateProvider_) + PSM(asset0_, asset1_, rateProvider_) {} function getAssetValue(address asset, uint256 amount) external view returns (uint256) { return _getAssetValue(asset, amount); From ce5653d46c2db8f39f0c7e1466ad93fd7a00b14e Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 30 May 2024 07:35:15 -0400 Subject: [PATCH 08/16] fix: cleanup diff --- src/PSM.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index f411e4e..5dcc952 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -45,8 +45,8 @@ contract PSM { asset1 = IERC20(asset1_); rateProvider = rateProvider_; - asset0Precision = 10 ** IERC20(asset0_).decimals(); - asset1Precision = 10 ** IERC20(asset1_).decimals(); + asset0Precision = 10 ** IERC20(asset0_).decimals(); + asset1Precision = 10 ** IERC20(asset1_).decimals(); } /**********************************************************************************************/ From 259cf16092399205e36a12d43ba289c7d02c7cce Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 30 May 2024 07:51:47 -0400 Subject: [PATCH 09/16] fix: update to use initial deposit instead of burn --- src/PSM.sol | 1 + test/InflationAttack.t.sol | 35 +++++++++++------------------------ 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index 5dcc952..679c5c2 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -14,6 +14,7 @@ interface IRateProviderLike { // TODO: Refactor into inheritance structure // TODO: Add interface with natspec and inherit // TODO: Prove that we're always rounding against user +// TODO: Add receiver to deposit/withdraw contract PSM { using SafeERC20 for IERC20; diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index f0b4c4d..6c0c047 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -9,7 +9,9 @@ import { PSMTestBase } from "test/PSMTestBase.sol"; contract InflationAttackTests is PSMTestBase { - function test_inflationAttack_noInitialBurnAmount() public { + // TODO: Add DOS attack test outlined here: https://github.com/marsfoundation/spark-psm/pull/2#pullrequestreview-2085880206 + + function test_inflationAttack_noInitialDeposit() public { psm = new PSM(address(usdc), address(sDai), address(rateProvider)); address firstDepositor = makeAddr("firstDepositor"); @@ -54,32 +56,19 @@ contract InflationAttackTests is PSMTestBase { assertEq(usdc.balanceOf(frontRunner), 15_000_000e6); } - function test_inflationAttack_useInitialBurnAmount_firstDepositOverflowBoundary() public { - psm = new PSM(address(usdc), address(sDai), address(rateProvider)); - - address frontRunner = makeAddr("frontRunner"); - - vm.startPrank(frontRunner); - sDai.mint(frontRunner, 800); - sDai.approve(address(psm), 800); - - vm.expectRevert(stdError.arithmeticError); - psm.deposit(address(sDai), 799); - - // 800 sDAI = 1000 shares - psm.deposit(address(sDai), 800); - } - - function test_inflationAttack_useInitialBurnAmount() public { + function test_inflationAttack_useInitialDeposit() public { psm = new PSM(address(usdc), address(sDai), address(rateProvider)); address firstDepositor = makeAddr("firstDepositor"); address frontRunner = makeAddr("frontRunner"); + address deployer = address(this); // TODO: Update to use non-deployer receiver + + _deposit(address(this), address(sDai), 800); /// 1000 shares // Step 1: Front runner deposits 801 sDAI to get 1 share - // 1000 shares get burned, user is left with 1 - _deposit(frontRunner, address(sDai), 801); + // User tries to do the same attack, depositing one sDAI for 1 share + _deposit(frontRunner, address(sDai), 1); assertEq(psm.shares(frontRunner), 1); @@ -107,14 +96,12 @@ contract InflationAttackTests is PSMTestBase { _withdraw(firstDepositor, address(usdc), type(uint256).max); _withdraw(frontRunner, address(usdc), type(uint256).max); - - // Burnt shares have a claim on these - // TODO: Should this be an admin contract instead of address(0)? - assertEq(usdc.balanceOf(address(psm)), 9_993_337.774818e6); + _withdraw(deployer, address(usdc), type(uint256).max); // Front runner loses 9.99m USDC, first depositor loses 4k USDC assertEq(usdc.balanceOf(firstDepositor), 19_996_668.887408e6); assertEq(usdc.balanceOf(frontRunner), 9_993.337774e6); + assertEq(usdc.balanceOf(deployer), 9_993_337.774818e6); } } From 62ef5c2fdb22a06544a34f54e8552b83541a84f4 Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 30 May 2024 08:00:44 -0400 Subject: [PATCH 10/16] feat: add readme section explaining attack --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 13ab5fe..e3336fb 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ PSM contracts to either: - Convert between a tokenization of an asset (ex. USDC) and a yield-bearing version of the asset (ex. sDAI). - Convert one to one between directly correlated assets (ex. USDC-DAI). +## [CRITICAL]: First Depositor Attack Prevention on Deployment + +On the deployment of the PSM, the deployer **MUST make an initial deposit in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). 1000 shares minted is determined to be sufficient to prevent this attack. Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2). + ## Usage ```bash From 07201506a1009dd92de775096ec21da3c01387ac Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 30 May 2024 08:01:48 -0400 Subject: [PATCH 11/16] fix: minimize diff --- src/PSM.sol | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index 679c5c2..d9bd589 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -33,11 +33,7 @@ contract PSM { mapping(address user => uint256 shares) public shares; - constructor( - address asset0_, - address asset1_, - address rateProvider_ - ) { + constructor(address asset0_, address asset1_, address rateProvider_) { require(asset0_ != address(0), "PSM/invalid-asset0"); require(asset1_ != address(0), "PSM/invalid-asset1"); require(rateProvider_ != address(0), "PSM/invalid-rateProvider"); From 44a5d52162acb0acfb749c4c6d0e93b33d300828 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 3 Jun 2024 09:06:46 -0400 Subject: [PATCH 12/16] feat: update to address comments outside sharesToBurn --- src/PSM.sol | 22 +++++++++++----------- test/Withdraw.t.sol | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index d9bd589..72fd270 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -108,11 +108,11 @@ contract PSM { /*** Deposit/withdraw preview functions ***/ /**********************************************************************************************/ - function previewDeposit(address asset, uint256 assets) public view returns (uint256) { + function previewDeposit(address asset, uint256 assetsToDeposit) public view returns (uint256) { require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset"); // Convert amount to 1e18 precision denominated in value of asset0 then convert to shares. - return convertToShares(_getAssetValue(asset, assets)); + return convertToShares(_getAssetValue(asset, assetsToDeposit)); } function previewWithdraw(address asset, uint256 maxAssetsToWithdraw) @@ -128,10 +128,10 @@ contract PSM { sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); - // TODO: Refactor this section to not use convertToAssets because of redundant check - // TODO: Can this cause an underflow in shares? Refactor to use full shares balance? - if (sharesToBurn > shares[msg.sender]) { - assetsWithdrawn = convertToAssets(asset, shares[msg.sender]); + uint256 userShares = shares[msg.sender]; + + if (sharesToBurn > userShares) { + assetsWithdrawn = convertToAssets(asset, userShares); sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); } } @@ -203,15 +203,15 @@ contract PSM { function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) { uint256 totalValue = getPsmTotalValue(); if (totalValue != 0) { - return _divRoundUp(assetValue * totalShares, totalValue); + return _divUp(assetValue * totalShares, totalValue); } return assetValue; } - function _divRoundUp(uint256 numerator_, uint256 divisor_) - internal pure returns (uint256 result_) - { - result_ = (numerator_ + divisor_ - 1) / divisor_; + function _divUp(uint256 x, uint256 y) internal pure returns (uint256 z) { + unchecked { + z = x != 0 ? ((x - 1) / y) + 1 : 0; + } } function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) { diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index bcabccb..453bb27 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -310,7 +310,7 @@ contract PSMWithdrawTests is PSMTestBase { // ); } - function _checkPsmInvariant() internal { + function _checkPsmInvariant() internal view { uint256 totalSharesValue = psm.convertToAssetValue(psm.totalShares()); uint256 totalAssetsValue = sDai.balanceOf(address(psm)) * rateProvider.getConversionRate() / 1e27 From 9c3958e204ee6b325249441bedd54197f0331647 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 3 Jun 2024 09:18:34 -0400 Subject: [PATCH 13/16] feat: update inflation attack test and readme --- README.md | 2 +- test/InflationAttack.t.sol | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index e3336fb..ff69c49 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ PSM contracts to either: ## [CRITICAL]: First Depositor Attack Prevention on Deployment -On the deployment of the PSM, the deployer **MUST make an initial deposit in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). 1000 shares minted is determined to be sufficient to prevent this attack. Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2). +On the deployment of the PSM, the deployer **MUST make an initial deposit to get at least 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). 1e18 shares minted is determined to be sufficient to prevent this attack. Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2). ## Usage diff --git a/test/InflationAttack.t.sol b/test/InflationAttack.t.sol index 6c0c047..8cf773f 100644 --- a/test/InflationAttack.t.sol +++ b/test/InflationAttack.t.sol @@ -28,6 +28,8 @@ contract InflationAttackTests is PSMTestBase { deal(address(usdc), frontRunner, 10_000_000e6); + assertEq(psm.convertToAssetValue(1), 1); + vm.prank(frontRunner); usdc.transfer(address(psm), 10_000_000e6); @@ -63,9 +65,9 @@ contract InflationAttackTests is PSMTestBase { address frontRunner = makeAddr("frontRunner"); address deployer = address(this); // TODO: Update to use non-deployer receiver - _deposit(address(this), address(sDai), 800); /// 1000 shares + _deposit(address(this), address(sDai), 0.8e18); /// 1e18 shares - // Step 1: Front runner deposits 801 sDAI to get 1 share + // Step 1: Front runner deposits sDAI to get 1 share // User tries to do the same attack, depositing one sDAI for 1 share _deposit(frontRunner, address(sDai), 1); @@ -74,23 +76,25 @@ contract InflationAttackTests is PSMTestBase { // Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1) + assertEq(psm.convertToAssetValue(1), 1); + deal(address(usdc), frontRunner, 10_000_000e6); vm.prank(frontRunner); usdc.transfer(address(psm), 10_000_000e6); - // Much less inflated exchange rate - assertEq(psm.convertToAssetValue(1), 9990.009990009990009991e18); + // Still inflated, but all value is transferred to existing holder, deployer + assertEq(psm.convertToAssetValue(1), 0.00000000001e18); - // Step 3: First depositor deposits 20 million USDC, only gets one share because rounding - // error gives them 1 instead of 2 shares, worth 15m USDC + // Step 3: First depositor deposits 20 million USDC, this time rounding is not an issue + // so value reflected is much more accurate _deposit(firstDepositor, address(usdc), 20_000_000e6); - assertEq(psm.shares(firstDepositor), 2001); + assertEq(psm.shares(firstDepositor), 1.999999800000020001e18); // Higher amount of initial shares means lower rounding error - assertEq(psm.convertToAssetValue(2001), 19_996_668.887408394403731513e18); + assertEq(psm.convertToAssetValue(1.999999800000020001e18), 19_999_999.999999999996673334e18); // Step 4: Both users withdraw the max amount of funds they can @@ -98,10 +102,10 @@ contract InflationAttackTests is PSMTestBase { _withdraw(frontRunner, address(usdc), type(uint256).max); _withdraw(deployer, address(usdc), type(uint256).max); - // Front runner loses 9.99m USDC, first depositor loses 4k USDC - assertEq(usdc.balanceOf(firstDepositor), 19_996_668.887408e6); - assertEq(usdc.balanceOf(frontRunner), 9_993.337774e6); - assertEq(usdc.balanceOf(deployer), 9_993_337.774818e6); + // Front runner loses full 10m USDC to the deployer that had all shares at the beginning, first depositor loses nothing (1e-6 USDC) + assertEq(usdc.balanceOf(firstDepositor), 19_999_999.999999e6); + assertEq(usdc.balanceOf(frontRunner), 0); + assertEq(usdc.balanceOf(deployer), 10_000_000.000001e6); } } From 4c9cf09e2bdbd87b8a60d7f3b5b9935d2921b81d Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 3 Jun 2024 09:19:31 -0400 Subject: [PATCH 14/16] fix: update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ff69c49..f7d0ed8 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ PSM contracts to either: ## [CRITICAL]: First Depositor Attack Prevention on Deployment -On the deployment of the PSM, the deployer **MUST make an initial deposit to get at least 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). 1e18 shares minted is determined to be sufficient to prevent this attack. Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2). +On the deployment of the PSM, the deployer **MUST make an initial deposit to get AT LEAST 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2). ## Usage From cb9931ad4d3ce00ccfe22ea55469b183af8f598b Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 3 Jun 2024 09:27:22 -0400 Subject: [PATCH 15/16] feat: update test to constrain deposit/withdraw --- src/PSM.sol | 2 +- test/Withdraw.t.sol | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/PSM.sol b/src/PSM.sol index 72fd270..d3c13dd 100644 --- a/src/PSM.sol +++ b/src/PSM.sol @@ -132,7 +132,7 @@ contract PSM { if (sharesToBurn > userShares) { assetsWithdrawn = convertToAssets(asset, userShares); - sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn)); + sharesToBurn = userShares; } } diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index 453bb27..28b6731 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -168,7 +168,11 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(psm.shares(user2), 0); // Burns the users full amount of shares } - function testFuzz_withdraw_multiUser( + // Adding this test to demonstrate that numbers are exact and correspond to assets deposits/withdrawals when withdrawals + // aren't greater than the user's share balance. The next test doesn't constrain this, but there are rounding errors of + // up to 1e12 for USDC because of the difference in asset precision. Up to 1e12 shares can be burned for 0 USDC in some + // cases, but this is an intentional rounding error against the user. + function testFuzz_withdraw_multiUser_noFullShareBurns( uint256 depositAmount1, uint256 depositAmount2, uint256 depositAmount3, @@ -184,9 +188,9 @@ contract PSMWithdrawTests is PSMTestBase { depositAmount2 = bound(depositAmount2, 0, USDC_TOKEN_MAX); depositAmount3 = bound(depositAmount3, 0, SDAI_TOKEN_MAX); - withdrawAmount1 = bound(withdrawAmount1, 0, USDC_TOKEN_MAX); - withdrawAmount2 = bound(withdrawAmount2, 0, USDC_TOKEN_MAX); - withdrawAmount3 = bound(withdrawAmount3, 0, SDAI_TOKEN_MAX); + withdrawAmount1 = bound(withdrawAmount1, depositAmount1, USDC_TOKEN_MAX); + withdrawAmount2 = bound(withdrawAmount2, depositAmount2, USDC_TOKEN_MAX); + withdrawAmount3 = bound(withdrawAmount3, 0, SDAI_TOKEN_MAX); _deposit(user1, address(usdc), depositAmount1); _deposit(user2, address(usdc), depositAmount2); From e576672c71e6043f7ce806d3adb356f0901f9398 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 3 Jun 2024 09:45:55 -0400 Subject: [PATCH 16/16] feat: update to add both cases --- test/Withdraw.t.sol | 92 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/test/Withdraw.t.sol b/test/Withdraw.t.sol index 28b6731..66fe2cb 100644 --- a/test/Withdraw.t.sol +++ b/test/Withdraw.t.sol @@ -188,10 +188,68 @@ contract PSMWithdrawTests is PSMTestBase { depositAmount2 = bound(depositAmount2, 0, USDC_TOKEN_MAX); depositAmount3 = bound(depositAmount3, 0, SDAI_TOKEN_MAX); - withdrawAmount1 = bound(withdrawAmount1, depositAmount1, USDC_TOKEN_MAX); - withdrawAmount2 = bound(withdrawAmount2, depositAmount2, USDC_TOKEN_MAX); - withdrawAmount3 = bound(withdrawAmount3, 0, SDAI_TOKEN_MAX); + withdrawAmount1 = bound(withdrawAmount1, 0, USDC_TOKEN_MAX); + withdrawAmount2 = bound(withdrawAmount2, 0, depositAmount2); // User can't burn up to 1e12 shares for 0 USDC in this case + withdrawAmount3 = bound(withdrawAmount3, 0, SDAI_TOKEN_MAX); + + // Run with zero share tolerance because the rounding error shouldn't be introduced with the above constraints. + _runWithdrawFuzzTests( + 0, + depositAmount1, + depositAmount2, + depositAmount3, + withdrawAmount1, + withdrawAmount2, + withdrawAmount3 + ); + } + + function testFuzz_withdraw_multiUser_fullShareBurns( + uint256 depositAmount1, + uint256 depositAmount2, + uint256 depositAmount3, + uint256 withdrawAmount1, + uint256 withdrawAmount2, + uint256 withdrawAmount3 + ) + public + { + // NOTE: Not covering zero cases, 1e-2 at 1e6 used as min for now so exact values can + // be asserted + depositAmount1 = bound(depositAmount1, 0, USDC_TOKEN_MAX); + depositAmount2 = bound(depositAmount2, 0, USDC_TOKEN_MAX); + depositAmount3 = bound(depositAmount3, 0, SDAI_TOKEN_MAX); + withdrawAmount1 = bound(withdrawAmount1, 0, USDC_TOKEN_MAX); + withdrawAmount2 = bound(withdrawAmount2, 0, USDC_TOKEN_MAX); + withdrawAmount3 = bound(withdrawAmount3, 0, SDAI_TOKEN_MAX); + + // Run with 1e12 share tolerance because the rounding error will be introduced with the above constraints. + _runWithdrawFuzzTests( + 1e12, + depositAmount1, + depositAmount2, + depositAmount3, + withdrawAmount1, + withdrawAmount2, + withdrawAmount3 + ); + } + + // NOTE: For `assertApproxEqAbs` assertions, a difference calculation is used here instead of comparing + // the two values because this approach inherently asserts that the shares remaining are lower than the + // theoretical value, proving the PSM rounds agains the user. + function _runWithdrawFuzzTests( + uint256 usdcShareTolerance, + uint256 depositAmount1, + uint256 depositAmount2, + uint256 depositAmount3, + uint256 withdrawAmount1, + uint256 withdrawAmount2, + uint256 withdrawAmount3 + ) + internal + { _deposit(user1, address(usdc), depositAmount1); _deposit(user2, address(usdc), depositAmount2); _deposit(user2, address(sDai), depositAmount3); @@ -220,6 +278,9 @@ contract PSMWithdrawTests is PSMTestBase { totalValue ); + // NOTE: User 1 doesn't need a tolerance because their shares are 1e6 precision because they only + // deposited USDC. User 2 has a tolerance because they deposited sDAI which has 1e18 precision + // so there is a chance that the rounding will be off by up to 1e12. assertEq(usdc.balanceOf(user1), expectedWithdrawnAmount1); assertEq(usdc.balanceOf(user2), 0); assertEq(usdc.balanceOf(address(psm)), totalUsdc - expectedWithdrawnAmount1); @@ -252,12 +313,17 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); - assertEq( - psm.shares(user2), - (depositAmount2 * 1e12) + (depositAmount3 * 125/100) - (expectedWithdrawnAmount2 * 1e12) + assertApproxEqAbs( + ((depositAmount2 * 1e12) + (depositAmount3 * 125/100) - (expectedWithdrawnAmount2 * 1e12)) - psm.shares(user2), + 0, + usdcShareTolerance ); - assertEq(psm.totalShares(), totalValue - (expectedWithdrawnAmount1 + expectedWithdrawnAmount2) * 1e12); + assertApproxEqAbs( + (totalValue - (expectedWithdrawnAmount1 + expectedWithdrawnAmount2) * 1e12) - psm.totalShares(), + 0, + usdcShareTolerance + ); uint256 expectedWithdrawnAmount3 = _getExpectedWithdrawnAmount(sDai, user2, withdrawAmount3); @@ -287,15 +353,15 @@ contract PSMWithdrawTests is PSMTestBase { assertEq(psm.shares(user1), (depositAmount1 - expectedWithdrawnAmount1) * 1e12); assertApproxEqAbs( - psm.shares(user2), - (depositAmount2 * 1e12) + (depositAmount3 * 125/100) - (expectedWithdrawnAmount2 * 1e12) - (expectedWithdrawnAmount3 * 125/100), - 1 + ((depositAmount2 * 1e12) + (depositAmount3 * 125/100) - (expectedWithdrawnAmount2 * 1e12) - (expectedWithdrawnAmount3 * 125/100)) - psm.shares(user2), + 0, + usdcShareTolerance + 1 // 1 is added to the tolerance because of rounding error in sDAI calculations ); assertApproxEqAbs( - psm.totalShares(), - totalValue - (expectedWithdrawnAmount1 + expectedWithdrawnAmount2) * 1e12 - (expectedWithdrawnAmount3 * 125/100), - 1 + totalValue - (expectedWithdrawnAmount1 + expectedWithdrawnAmount2) * 1e12 - (expectedWithdrawnAmount3 * 125/100) - psm.totalShares(), + 0, + usdcShareTolerance + 1 // 1 is added to the tolerance because of rounding error in sDAI calculations ); // -- TODO: Get these to work, rounding assertions proving always rounding down