Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add minimum first deposit, refactor convert/preview functions (SC-446) #2

Merged
merged 16 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ PSM contracts to either:
- Convert between a tokenization of an asset (ex. USDC) and a yield-bearing version of the asset (ex. sDAI).
- Convert one to one between directly correlated assets (ex. USDC-DAI).

## [CRITICAL]: First Depositor Attack Prevention on Deployment

On the deployment of the PSM, the deployer **MUST make an initial deposit to get AT LEAST 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2).

## Usage

```bash
Expand Down
132 changes: 81 additions & 51 deletions src/PSM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ interface IRateProviderLike {
// TODO: Refactor into inheritance structure
// TODO: Add interface with natspec and inherit
// TODO: Prove that we're always rounding against user
// TODO: Frontrunning attack, donation attack, virtual balances?
// TODO: Figure out how to optimize require checks for assets in view functions
// TODO: Discuss if we should add ERC20 functionality
// TODO: Add receiver to deposit/withdraw
contract PSM {

using SafeERC20 for IERC20;
Expand Down Expand Up @@ -80,11 +78,10 @@ contract PSM {
/*** Liquidity provision functions ***/
/**********************************************************************************************/

function deposit(address asset, uint256 assetsToDeposit) external {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

// Convert amount to 1e18 precision denominated in value of asset0 then convert to shares.
uint256 newShares = convertToShares(_getAssetValue(asset, assetsToDeposit));
function deposit(address asset, uint256 assetsToDeposit)
external returns (uint256 newShares)
{
newShares = previewDeposit(asset, assetsToDeposit);

shares[msg.sender] += newShares;
totalShares += newShares;
Expand All @@ -94,6 +91,32 @@ contract PSM {

function withdraw(address asset, uint256 maxAssetsToWithdraw)
external returns (uint256 assetsWithdrawn)
{
uint256 sharesToBurn;

( sharesToBurn, assetsWithdrawn ) = previewWithdraw(asset, maxAssetsToWithdraw);

unchecked {
shares[msg.sender] -= sharesToBurn;
totalShares -= sharesToBurn;
}

IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn);
}

/**********************************************************************************************/
/*** Deposit/withdraw preview functions ***/
/**********************************************************************************************/

function previewDeposit(address asset, uint256 assetsToDeposit) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

// Convert amount to 1e18 precision denominated in value of asset0 then convert to shares.
return convertToShares(_getAssetValue(asset, assetsToDeposit));
}

function previewWithdraw(address asset, uint256 maxAssetsToWithdraw)
public view returns (uint256 sharesToBurn, uint256 assetsWithdrawn)
{
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

Expand All @@ -103,36 +126,43 @@ contract PSM {
? assetBalance
: maxAssetsToWithdraw;

uint256 sharesToBurn = convertToShares(asset, assetsWithdrawn);
sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn));

if (sharesToBurn > shares[msg.sender]) {
assetsWithdrawn = convertToAssets(asset, shares[msg.sender]);
sharesToBurn = convertToShares(asset, assetsWithdrawn);
}
uint256 userShares = shares[msg.sender];

unchecked {
shares[msg.sender] -= sharesToBurn;
totalShares -= sharesToBurn;
if (sharesToBurn > userShares) {
assetsWithdrawn = convertToAssets(asset, userShares);
sharesToBurn = userShares;
}

IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn);
}

/**********************************************************************************************/
/*** Conversion functions ***/
/*** Swap preview functions ***/
/**********************************************************************************************/

function convertToShares(uint256 assetValue) public view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return assetValue * totalShares / totalValue;
}
return assetValue;
function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) {
return amountIn
* 1e27
* asset1Precision
/ IRateProviderLike(rateProvider).getConversionRate()
/ asset0Precision;
}

function convertToShares(address asset, uint256 assets) public view returns (uint256) {
function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) {
return amountIn
* IRateProviderLike(rateProvider).getConversionRate()
* asset0Precision
/ 1e27
/ asset1Precision;
}

/**********************************************************************************************/
/*** Conversion functions ***/
/**********************************************************************************************/

function convertToAssets(address asset, uint256 numShares) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");
return convertToShares(_getAssetValue(asset, assets));
return _getAssetsByValue(asset, convertToAssetValue(numShares));
}

function convertToAssetValue(uint256 numShares) public view returns (uint256) {
Expand All @@ -144,9 +174,17 @@ contract PSM {
return numShares;
}

function convertToAssets(address asset, uint256 numShares) public view returns (uint256) {
function convertToShares(uint256 assetValue) public view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return assetValue * totalShares / totalValue;
}
return assetValue;
}

function convertToShares(address asset, uint256 assets) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");
return _getAssetsByValue(asset, convertToAssetValue(numShares));
return convertToShares(_getAssetValue(asset, assets));
}

/**********************************************************************************************/
Expand All @@ -159,35 +197,27 @@ contract PSM {
}

/**********************************************************************************************/
/*** Swap preview functions ***/
/*** Internal helper functions ***/
/**********************************************************************************************/

function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) {
return amountIn
* 1e27
* asset1Precision
/ IRateProviderLike(rateProvider).getConversionRate()
/ asset0Precision;
function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return _divUp(assetValue * totalShares, totalValue);
}
return assetValue;
}

function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) {
return amountIn
* IRateProviderLike(rateProvider).getConversionRate()
* asset0Precision
/ 1e27
/ asset1Precision;
function _divUp(uint256 x, uint256 y) internal pure returns (uint256 z) {
unchecked {
z = x != 0 ? ((x - 1) / y) + 1 : 0;
}
}

/**********************************************************************************************/
/*** Internal helper functions ***/
/**********************************************************************************************/

function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) {
if (asset == address(asset0)) {
return _getAsset0Value(amount);
}

return _getAsset1Value(amount);
return asset == address(asset0)
? _getAsset0Value(amount)
: _getAsset1Value(amount);
}

function _getAssetsByValue(address asset, uint256 assetValue) internal view returns (uint256) {
Expand Down
111 changes: 111 additions & 0 deletions test/InflationAttack.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import { PSM } from "src/PSM.sol";

import { PSMTestBase } from "test/PSMTestBase.sol";

contract InflationAttackTests is PSMTestBase {

// TODO: Add DOS attack test outlined here: https://github.com/marsfoundation/spark-psm/pull/2#pullrequestreview-2085880206

function test_inflationAttack_noInitialDeposit() public {
psm = new PSM(address(usdc), address(sDai), address(rateProvider));

address firstDepositor = makeAddr("firstDepositor");
address frontRunner = makeAddr("frontRunner");

// Step 1: Front runner deposits 1 sDAI to get 1 share

// Have to use sDai because 1 USDC mints 1e12 shares
_deposit(frontRunner, address(sDai), 1);

assertEq(psm.shares(frontRunner), 1);

// Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1)

deal(address(usdc), frontRunner, 10_000_000e6);

assertEq(psm.convertToAssetValue(1), 1);

vm.prank(frontRunner);
usdc.transfer(address(psm), 10_000_000e6);

// Highly inflated exchange rate
assertEq(psm.convertToAssetValue(1), 10_000_000e18 + 1);

// Step 3: First depositor deposits 20 million USDC, only gets one share because rounding
// error gives them 1 instead of 2 shares, worth 15m USDC

_deposit(firstDepositor, address(usdc), 20_000_000e6);

assertEq(psm.shares(firstDepositor), 1);

// 1 share = 3 million USDC / 2 shares = 1.5 million USDC
assertEq(psm.convertToAssetValue(1), 15_000_000e18);

// Step 4: Both users withdraw the max amount of funds they can

_withdraw(firstDepositor, address(usdc), type(uint256).max);
_withdraw(frontRunner, address(usdc), type(uint256).max);

assertEq(usdc.balanceOf(address(psm)), 0);

// Front runner profits 5m USDC, first depositor loses 5m USDC
assertEq(usdc.balanceOf(firstDepositor), 15_000_000e6);
assertEq(usdc.balanceOf(frontRunner), 15_000_000e6);
}

function test_inflationAttack_useInitialDeposit() public {
psm = new PSM(address(usdc), address(sDai), address(rateProvider));

address firstDepositor = makeAddr("firstDepositor");
address frontRunner = makeAddr("frontRunner");
address deployer = address(this); // TODO: Update to use non-deployer receiver

_deposit(address(this), address(sDai), 0.8e18); /// 1e18 shares

// Step 1: Front runner deposits sDAI to get 1 share

// User tries to do the same attack, depositing one sDAI for 1 share
_deposit(frontRunner, address(sDai), 1);

assertEq(psm.shares(frontRunner), 1);

// Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1)

assertEq(psm.convertToAssetValue(1), 1);

deal(address(usdc), frontRunner, 10_000_000e6);

vm.prank(frontRunner);
usdc.transfer(address(psm), 10_000_000e6);

// Still inflated, but all value is transferred to existing holder, deployer
assertEq(psm.convertToAssetValue(1), 0.00000000001e18);

// Step 3: First depositor deposits 20 million USDC, this time rounding is not an issue
// so value reflected is much more accurate

_deposit(firstDepositor, address(usdc), 20_000_000e6);

assertEq(psm.shares(firstDepositor), 1.999999800000020001e18);

// Higher amount of initial shares means lower rounding error
assertEq(psm.convertToAssetValue(1.999999800000020001e18), 19_999_999.999999999996673334e18);

// Step 4: Both users withdraw the max amount of funds they can

_withdraw(firstDepositor, address(usdc), type(uint256).max);
_withdraw(frontRunner, address(usdc), type(uint256).max);
_withdraw(deployer, address(usdc), type(uint256).max);

// Front runner loses full 10m USDC to the deployer that had all shares at the beginning, first depositor loses nothing (1e-6 USDC)
assertEq(usdc.balanceOf(firstDepositor), 19_999_999.999999e6);
assertEq(usdc.balanceOf(frontRunner), 0);
assertEq(usdc.balanceOf(deployer), 10_000_000.000001e6);
}

}
2 changes: 1 addition & 1 deletion test/PSMTestBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import { PSM } from "../src/PSM.sol";
import { PSM } from "src/PSM.sol";

import { MockERC20 } from "erc20-helpers/MockERC20.sol";

Expand Down
Loading
Loading