Skip to content

Commit 46b954e

Browse files
authored
🖼 Mock contract chaining behaviour (#816)
1 parent fb6863d commit 46b954e

File tree

8 files changed

+389
-45
lines changed

8 files changed

+389
-45
lines changed

‎.changeset/orange-deers-sit.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ethereum-waffle/mock-contract": patch
3+
---
4+
5+
Add mock contract chaining behaviour

‎docs/source/mock-contract.rst

+49
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,55 @@ Mock contract will be used to mock exactly this call with values that are releva
123123
});
124124
});
125125
126+
Mocking multiple calls
127+
----------------------
128+
129+
Mock contract allows to queue multiple mock calls to the same function. This can only be done if the function is not pure or view. That's because the mock call queue is stored on the blockchain and we need to modify it.
130+
131+
.. code-block:: ts
132+
133+
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
134+
135+
await mockContract.<nameOfMethod>() // returns <value1>
136+
await mockContract.<nameOfMethod>() // returns <value2>
137+
138+
Just like with regular mock calls, the queue can be set up to revert or return a specified value. It can also be set up to return different values for different arguments.
139+
140+
.. code-block:: ts
141+
142+
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
143+
await mockContract.mock.<nameOfMethod>.withArgs(<arguments1>).returns(<value3>);
144+
145+
await mockContract.<nameOfMethod>() // returns <value1>
146+
await mockContract.<nameOfMethod>() // returns <value2>
147+
await mockContract.<nameOfMethod>(<arguments1>) // returns <value3>
148+
149+
Keep in mind that the mocked revert must be at the end of the queue, because it prevents the contract from updating the queue.
150+
151+
.. code-block:: ts
152+
153+
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>).reverts();
154+
155+
await mockContract.<nameOfMethod>() // returns <value1>
156+
await mockContract.<nameOfMethod>() // returns <value2>
157+
await mockContract.<nameOfMethod>() // reverts
158+
159+
When the queue is empty, the mock contract will return the last value from the queue and each time the you set up a new queue, the old one is overwritten.
160+
161+
.. code-block:: ts
162+
163+
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
164+
165+
await mockContract.<nameOfMethod>() // returns <value1>
166+
await mockContract.<nameOfMethod>() // returns <value2>
167+
await mockContract.<nameOfMethod>() // returns <value2>
168+
169+
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
170+
await mockContract.mock.<nameOfMethod>.returns(<value3>).returns(<value4>);
171+
172+
await mockContract.<nameOfMethod>() // returns <value3>
173+
await mockContract.<nameOfMethod>() // returns <value4>
174+
126175
Mocking receive function
127176
------------------------
128177

‎waffle-mock-contract/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"module": "dist/esm/src/index.ts",
3333
"types": "dist/esm/src/index.d.ts",
3434
"scripts": {
35-
"test": "export NODE_ENV=test && mocha",
35+
"test": "ts-node ./test/helpers/buildTestContracts.ts && export NODE_ENV=test && mocha",
3636
"lint": "eslint '{src,test}/**/*.ts'",
3737
"lint:fix": "eslint --fix '{src,test}/**/*.ts'",
3838
"build": "rimraf ./dist && yarn build:sol && yarn build:esm && yarn build:cjs && ts-node ./test/helpers/buildTestContracts.ts",

‎waffle-mock-contract/src/Doppelganger.sol

+83-15
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,101 @@
22
pragma solidity ^0.6.3;
33

44
contract Doppelganger {
5+
6+
// ============================== Linked list queues data structure explainer ==============================
7+
// mockConfig contains multiple linked lists, one for each unique call
8+
// mockConfig[<callData hash>] => root node of the linked list for this call
9+
// mockConfig[<callData hash>].next => 'address' of the next node. It's always defined, even if it's the last node.
10+
// When defining a new node .next is set to the hash of the 'address' of the last node
11+
// mockConfig[mockConfig[<callData hash>].next] => next node (possibly undefined)
12+
// tails[<callData hash>] => 'address' of the node 'one after' the last node (<last node>.next)
13+
// in the linked list with root node at <callData hash>
14+
515
struct MockCall {
6-
bool initialized;
16+
bytes32 next;
717
bool reverts;
818
string revertReason;
919
bytes returnValue;
1020
}
11-
1221
mapping(bytes32 => MockCall) mockConfig;
22+
mapping(bytes32 => bytes32) tails;
1323
bool receiveReverts;
1424
string receiveRevertReason;
1525

1626
fallback() external payable {
17-
MockCall storage mockCall = __internal__getMockCall();
27+
MockCall memory mockCall = __internal__getMockCall();
1828
if (mockCall.reverts == true) {
1929
__internal__mockRevert(mockCall.revertReason);
2030
return;
2131
}
2232
__internal__mockReturn(mockCall.returnValue);
2333
}
24-
34+
2535
receive() payable external {
2636
require(receiveReverts == false, receiveRevertReason);
2737
}
38+
39+
function __clearQueue(bytes32 at) private {
40+
tails[at] = at;
41+
while(mockConfig[at].next != "") {
42+
bytes32 next = mockConfig[at].next;
43+
delete mockConfig[at];
44+
at = next;
45+
}
46+
}
2847

29-
function __waffle__mockReverts(bytes memory data, string memory reason) public {
30-
mockConfig[keccak256(data)] = MockCall({
31-
initialized: true,
48+
function __waffle__queueRevert(bytes memory data, string memory reason) public {
49+
// get the root node of the linked list for this call
50+
bytes32 root = keccak256(data);
51+
52+
// get the 'address' of the node 'one after' the last node
53+
// this is where the new node will be inserted
54+
bytes32 tail = tails[root];
55+
if(tail == "") tail = keccak256(data);
56+
57+
// new tail is set to the hash of the current tail
58+
tails[root] = keccak256(abi.encodePacked(tail));
59+
60+
// initialize the new node
61+
mockConfig[tail] = MockCall({
62+
next: tails[root],
3263
reverts: true,
3364
revertReason: reason,
3465
returnValue: ""
3566
});
3667
}
3768

38-
function __waffle__mockReturns(bytes memory data, bytes memory value) public {
39-
mockConfig[keccak256(data)] = MockCall({
40-
initialized: true,
69+
function __waffle__mockReverts(bytes memory data, string memory reason) public {
70+
__clearQueue(keccak256(data));
71+
__waffle__queueRevert(data, reason);
72+
}
73+
74+
function __waffle__queueReturn(bytes memory data, bytes memory value) public {
75+
// get the root node of the linked list for this call
76+
bytes32 root = keccak256(data);
77+
78+
// get the 'address' of the node 'one after' the last node
79+
// this is where the new node will be inserted
80+
bytes32 tail = tails[root];
81+
if(tail == "") tail = keccak256(data);
82+
83+
// new tail is set to the hash of the current tail
84+
tails[root] = keccak256(abi.encodePacked(tail));
85+
86+
// initialize the new node
87+
mockConfig[tail] = MockCall({
88+
next: tails[root],
4189
reverts: false,
4290
revertReason: "",
4391
returnValue: value
4492
});
4593
}
4694

95+
function __waffle__mockReturns(bytes memory data, bytes memory value) public {
96+
__clearQueue(keccak256(data));
97+
__waffle__queueReturn(data, value);
98+
}
99+
47100
function __waffle__receiveReverts(string memory reason) public {
48101
receiveReverts = true;
49102
receiveRevertReason = reason;
@@ -61,15 +114,30 @@ contract Doppelganger {
61114
return returnValue;
62115
}
63116

64-
function __internal__getMockCall() view private returns (MockCall storage mockCall) {
65-
mockCall = mockConfig[keccak256(msg.data)];
66-
if (mockCall.initialized == true) {
117+
function __internal__getMockCall() private returns (MockCall memory mockCall) {
118+
// get the root node of the queue for this call
119+
bytes32 root = keccak256(msg.data);
120+
mockCall = mockConfig[root];
121+
if (mockCall.next != "") {
67122
// Mock method with specified arguments
123+
124+
// If there is a next mock call, set it as the current mock call
125+
// We check if the next mock call is defined by checking if it has a 'next' variable defined
126+
// (next value is always defined, even if it's the last mock call)
127+
if(mockConfig[mockCall.next].next != ""){ // basically if it's not the last mock call
128+
mockConfig[root] = mockConfig[mockCall.next];
129+
delete mockConfig[mockCall.next];
130+
}
68131
return mockCall;
69132
}
70-
mockCall = mockConfig[keccak256(abi.encodePacked(msg.sig))];
71-
if (mockCall.initialized == true) {
133+
root = keccak256(abi.encodePacked(msg.sig));
134+
mockCall = mockConfig[root];
135+
if (mockCall.next != "") {
72136
// Mock method with any arguments
137+
if(mockConfig[mockCall.next].next != ""){ // same as above
138+
mockConfig[root] = mockConfig[mockCall.next];
139+
delete mockConfig[mockCall.next];
140+
}
73141
return mockCall;
74142
}
75143
revert("Mock on the method is not initialized");

‎waffle-mock-contract/src/index.ts

+115-27
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,126 @@ import type {JsonRpcProvider} from '@ethersproject/providers';
66

77
type ABI = string | Array<utils.Fragment | JsonFragment | string>
88

9-
export type Stub = ReturnType<typeof stub>;
10-
11-
type DeployOptions = {
12-
address: string;
13-
override?: boolean;
9+
interface StubInterface {
10+
returns(...args: any): StubInterface;
11+
reverts(): StubInterface;
12+
revertsWithReason(reason: string): StubInterface;
13+
withArgs(...args: any[]): StubInterface;
1414
}
1515

1616
export interface MockContract extends Contract {
1717
mock: {
18-
[key: string]: Stub;
18+
[key: string]: StubInterface;
1919
};
2020
call (contract: Contract, functionName: string, ...params: any[]): Promise<any>;
2121
staticcall (contract: Contract, functionName: string, ...params: any[]): Promise<any>;
2222
}
2323

24+
class Stub implements StubInterface {
25+
callData: string;
26+
stubCalls: Array<() => Promise<any>> = [];
27+
revertSet = false;
28+
argsSet = false;
29+
30+
constructor(
31+
private mockContract: Contract,
32+
private encoder: utils.AbiCoder,
33+
private func: utils.FunctionFragment
34+
) {
35+
this.callData = mockContract.interface.getSighash(func);
36+
}
37+
38+
private err(reason: string): never {
39+
this.stubCalls = [];
40+
this.revertSet = false;
41+
this.argsSet = false;
42+
throw new Error(reason);
43+
}
44+
45+
returns(...args: any) {
46+
if (this.revertSet) this.err('Revert must be the last call');
47+
if (!this.func.outputs) this.err('Cannot mock return values from a void function');
48+
const encoded = this.encoder.encode(this.func.outputs, args);
49+
50+
// if there no calls then this is the first call and we need to use mockReturns to override the queue
51+
if (this.stubCalls.length === 0) {
52+
this.stubCalls.push(async () => {
53+
await this.mockContract.__waffle__mockReturns(this.callData, encoded);
54+
});
55+
} else {
56+
this.stubCalls.push(async () => {
57+
await this.mockContract.__waffle__queueReturn(this.callData, encoded);
58+
});
59+
}
60+
return this;
61+
}
62+
63+
reverts() {
64+
if (this.revertSet) this.err('Revert must be the last call');
65+
66+
// if there no calls then this is the first call and we need to use mockReturns to override the queue
67+
if (this.stubCalls.length === 0) {
68+
this.stubCalls.push(async () => {
69+
await this.mockContract.__waffle__mockReverts(this.callData, 'Mock revert');
70+
});
71+
} else {
72+
this.stubCalls.push(async () => {
73+
await this.mockContract.__waffle__queueRevert(this.callData, 'Mock revert');
74+
});
75+
}
76+
this.revertSet = true;
77+
return this;
78+
}
79+
80+
revertsWithReason(reason: string) {
81+
if (this.revertSet) this.err('Revert must be the last call');
82+
83+
// if there no calls then this is the first call and we need to use mockReturns to override the queue
84+
if (this.stubCalls.length === 0) {
85+
this.stubCalls.push(async () => {
86+
await this.mockContract.__waffle__mockReverts(this.callData, reason);
87+
});
88+
} else {
89+
this.stubCalls.push(async () => {
90+
await this.mockContract.__waffle__queueRevert(this.callData, reason);
91+
});
92+
}
93+
this.revertSet = true;
94+
return this;
95+
}
96+
97+
withArgs(...params: any[]) {
98+
if (this.argsSet) this.err('withArgs can be called only once');
99+
this.callData = this.mockContract.interface.encodeFunctionData(this.func, params);
100+
this.argsSet = true;
101+
return this;
102+
}
103+
104+
async then(resolve: () => void, reject: (e: any) => void) {
105+
for (let i = 0; i < this.stubCalls.length; i++) {
106+
try {
107+
await this.stubCalls[i]();
108+
} catch (e) {
109+
this.stubCalls = [];
110+
this.argsSet = false;
111+
this.revertSet = false;
112+
reject(e);
113+
return;
114+
}
115+
}
116+
117+
this.stubCalls = [];
118+
this.argsSet = false;
119+
this.revertSet = false;
120+
resolve();
121+
}
122+
}
123+
124+
type DeployOptions = {
125+
address: string;
126+
override?: boolean;
127+
}
128+
24129
async function deploy(signer: Signer, options?: DeployOptions) {
25130
if (options) {
26131
const {address, override} = options;
@@ -50,29 +155,12 @@ async function deploy(signer: Signer, options?: DeployOptions) {
50155
return factory.deploy();
51156
}
52157

53-
function stub(mockContract: Contract, encoder: utils.AbiCoder, func: utils.FunctionFragment, params?: any[]) {
54-
const callData = params
55-
? mockContract.interface.encodeFunctionData(func, params)
56-
: mockContract.interface.getSighash(func);
57-
58-
return {
59-
returns: async (...args: any) => {
60-
if (!func.outputs) return;
61-
const encoded = encoder.encode(func.outputs, args);
62-
await mockContract.__waffle__mockReturns(callData, encoded);
63-
},
64-
reverts: async () => mockContract.__waffle__mockReverts(callData, 'Mock revert'),
65-
revertsWithReason: async (reason: string) => mockContract.__waffle__mockReverts(callData, reason),
66-
withArgs: (...args: any[]) => stub(mockContract, encoder, func, args)
67-
};
68-
}
69-
70158
function createMock(abi: ABI, mockContractInstance: Contract) {
71159
const {functions} = new utils.Interface(abi);
72160
const encoder = new utils.AbiCoder();
73161

74162
const mockedAbi = Object.values(functions).reduce((acc, func) => {
75-
const stubbed = stub(mockContractInstance, encoder, func);
163+
const stubbed = new Stub(mockContractInstance as MockContract, encoder, func);
76164
return {
77165
...acc,
78166
[func.name]: stubbed,
@@ -81,10 +169,10 @@ function createMock(abi: ABI, mockContractInstance: Contract) {
81169
}, {} as MockContract['mock']);
82170

83171
mockedAbi.receive = {
84-
returns: async () => { throw new Error('Receive function return is not implemented.'); },
172+
returns: () => { throw new Error('Receive function return is not implemented.'); },
85173
withArgs: () => { throw new Error('Receive function return is not implemented.'); },
86-
reverts: async () => mockContractInstance.__waffle__receiveReverts('Mock Revert'),
87-
revertsWithReason: async (reason: string) => mockContractInstance.__waffle__receiveReverts(reason)
174+
reverts: () => mockContractInstance.__waffle__receiveReverts('Mock Revert'),
175+
revertsWithReason: (reason: string) => mockContractInstance.__waffle__receiveReverts(reason)
88176
};
89177

90178
return mockedAbi;

0 commit comments

Comments
 (0)