diff --git a/contracts/receivers/aave-v2-receiver/main.sol b/contracts/receivers/aave-v2-receiver/main.sol index 292bb62..0d15fde 100644 --- a/contracts/receivers/aave-v2-receiver/main.sol +++ b/contracts/receivers/aave-v2-receiver/main.sol @@ -4,9 +4,10 @@ pragma experimental ABIEncoderV2; import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/SafeERC20.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { DSMath } from "../../common/math.sol"; import { AccountInterface } from "./interfaces.sol"; -contract MigrateResolver { +contract MigrateResolver is DSMath { using SafeERC20 for IERC20; struct AaveData { @@ -21,6 +22,7 @@ contract MigrateResolver { uint private lastStateId; mapping (address => AaveData) public positions; + mapping(address => mapping(address => uint)) deposits; function _migratePosition(address owner) internal { AaveData storage data = positions[owner]; @@ -33,6 +35,40 @@ contract MigrateResolver { data.isFinal = true; } + function deposit(address[] calldata tokens, uint[] calldata amts) external { + uint _length = tokens.length; + require(_length == amts.length, "invalid-length"); + + for (uint256 i = 0; i < _length; i++) { + address _token = tokens[i]; + + IERC20 tokenContract = IERC20(_token); + uint _amt = amts[i] == uint(-1) ? tokenContract.balanceOf(msg.sender) : amts[i]; + tokenContract.safeTransferFrom(msg.sender, address(this), _amt); + + deposits[msg.sender][_token] = _amt; + } + } + + function withdraw(address[] calldata tokens, uint[] calldata amts) external { + uint _length = tokens.length; + require(_length == amts.length, "invalid-length"); + + for (uint256 i = 0; i < _length; i++) { + uint _amt = amts[i]; + address _token = tokens[i]; + uint maxAmt = deposits[msg.sender][_token]; + + if (_amt > maxAmt) { + _amt = maxAmt; + } + + IERC20(_token).safeTransfer(msg.sender, _amt); + + deposits[msg.sender][_token] = sub(maxAmt, _amt); + } + } + function onStateReceive(uint256 stateId, bytes calldata receivedData) external { // require(stateId > lastStateId, "wrong-data"); lastStateId = stateId; diff --git a/contracts/senders/aave-v2-migrator/main.sol b/contracts/senders/aave-v2-migrator/main.sol index 027a424..4a31415 100644 --- a/contracts/senders/aave-v2-migrator/main.sol +++ b/contracts/senders/aave-v2-migrator/main.sol @@ -4,6 +4,7 @@ pragma experimental ABIEncoderV2; import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/SafeERC20.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { TokenInterface } from "../../common/interfaces.sol"; import { Helpers } from "./helpers.sol"; import { AaveInterface, ATokenInterface } from "./interfaces.sol"; @@ -22,13 +23,39 @@ contract LiquidityResolver is Helpers { if (_token == ethAddr) { require(msg.value == amts[i]); _amt = msg.value; + + TokenInterface(wethAddr).deposit{value: msg.value}(); } else { IERC20 tokenContract = IERC20(_token); _amt = amts[i] == uint(-1) ? tokenContract.balanceOf(msg.sender) : amts[i]; tokenContract.safeTransferFrom(msg.sender, address(this), _amt); } - deposits[_token][msg.sender] = _amt; + deposits[msg.sender][_token] = _amt; + } + } + + function withdraw(address[] calldata tokens, uint[] calldata amts) external { + uint _length = tokens.length; + require(_length == amts.length, "invalid-length"); + + for (uint256 i = 0; i < _length; i++) { + uint _amt = amts[i]; + address _token = tokens[i]; + uint maxAmt = deposits[msg.sender][_token]; + + if (_amt > maxAmt) { + _amt = maxAmt; + } + + if (_token == ethAddr) { + TokenInterface(wethAddr).withdraw(_amt); + msg.sender.call{value: _amt}(""); + } else { + IERC20(_token).safeTransfer(msg.sender, _amt); + } + + deposits[msg.sender][_token] = sub(maxAmt, _amt); } } }