@@ -6,21 +6,126 @@ import type {JsonRpcProvider} from '@ethersproject/providers';
6
6
7
7
type ABI = string | Array < utils . Fragment | JsonFragment | string >
8
8
9
- export type Stub = ReturnType < typeof stub > ;
10
-
11
- type DeployOptions = {
12
- address : string ;
13
- override ?: boolean ;
9
+ interface StubInterface {
10
+ returns ( ... args : any ) : StubInterface ;
11
+ reverts ( ) : StubInterface ;
12
+ revertsWithReason ( reason : string ) : StubInterface ;
13
+ withArgs ( ... args : any [ ] ) : StubInterface ;
14
14
}
15
15
16
16
export interface MockContract extends Contract {
17
17
mock : {
18
- [ key : string ] : Stub ;
18
+ [ key : string ] : StubInterface ;
19
19
} ;
20
20
call ( contract : Contract , functionName : string , ...params : any [ ] ) : Promise < any > ;
21
21
staticcall ( contract : Contract , functionName : string , ...params : any [ ] ) : Promise < any > ;
22
22
}
23
23
24
+ class Stub implements StubInterface {
25
+ callData : string ;
26
+ stubCalls : Array < ( ) => Promise < any > > = [ ] ;
27
+ revertSet = false ;
28
+ argsSet = false ;
29
+
30
+ constructor (
31
+ private mockContract : Contract ,
32
+ private encoder : utils . AbiCoder ,
33
+ private func : utils . FunctionFragment
34
+ ) {
35
+ this . callData = mockContract . interface . getSighash ( func ) ;
36
+ }
37
+
38
+ private err ( reason : string ) : never {
39
+ this . stubCalls = [ ] ;
40
+ this . revertSet = false ;
41
+ this . argsSet = false ;
42
+ throw new Error ( reason ) ;
43
+ }
44
+
45
+ returns ( ...args : any ) {
46
+ if ( this . revertSet ) this . err ( 'Revert must be the last call' ) ;
47
+ if ( ! this . func . outputs ) this . err ( 'Cannot mock return values from a void function' ) ;
48
+ const encoded = this . encoder . encode ( this . func . outputs , args ) ;
49
+
50
+ // if there no calls then this is the first call and we need to use mockReturns to override the queue
51
+ if ( this . stubCalls . length === 0 ) {
52
+ this . stubCalls . push ( async ( ) => {
53
+ await this . mockContract . __waffle__mockReturns ( this . callData , encoded ) ;
54
+ } ) ;
55
+ } else {
56
+ this . stubCalls . push ( async ( ) => {
57
+ await this . mockContract . __waffle__queueReturn ( this . callData , encoded ) ;
58
+ } ) ;
59
+ }
60
+ return this ;
61
+ }
62
+
63
+ reverts ( ) {
64
+ if ( this . revertSet ) this . err ( 'Revert must be the last call' ) ;
65
+
66
+ // if there no calls then this is the first call and we need to use mockReturns to override the queue
67
+ if ( this . stubCalls . length === 0 ) {
68
+ this . stubCalls . push ( async ( ) => {
69
+ await this . mockContract . __waffle__mockReverts ( this . callData , 'Mock revert' ) ;
70
+ } ) ;
71
+ } else {
72
+ this . stubCalls . push ( async ( ) => {
73
+ await this . mockContract . __waffle__queueRevert ( this . callData , 'Mock revert' ) ;
74
+ } ) ;
75
+ }
76
+ this . revertSet = true ;
77
+ return this ;
78
+ }
79
+
80
+ revertsWithReason ( reason : string ) {
81
+ if ( this . revertSet ) this . err ( 'Revert must be the last call' ) ;
82
+
83
+ // if there no calls then this is the first call and we need to use mockReturns to override the queue
84
+ if ( this . stubCalls . length === 0 ) {
85
+ this . stubCalls . push ( async ( ) => {
86
+ await this . mockContract . __waffle__mockReverts ( this . callData , reason ) ;
87
+ } ) ;
88
+ } else {
89
+ this . stubCalls . push ( async ( ) => {
90
+ await this . mockContract . __waffle__queueRevert ( this . callData , reason ) ;
91
+ } ) ;
92
+ }
93
+ this . revertSet = true ;
94
+ return this ;
95
+ }
96
+
97
+ withArgs ( ...params : any [ ] ) {
98
+ if ( this . argsSet ) this . err ( 'withArgs can be called only once' ) ;
99
+ this . callData = this . mockContract . interface . encodeFunctionData ( this . func , params ) ;
100
+ this . argsSet = true ;
101
+ return this ;
102
+ }
103
+
104
+ async then ( resolve : ( ) => void , reject : ( e : any ) => void ) {
105
+ for ( let i = 0 ; i < this . stubCalls . length ; i ++ ) {
106
+ try {
107
+ await this . stubCalls [ i ] ( ) ;
108
+ } catch ( e ) {
109
+ this . stubCalls = [ ] ;
110
+ this . argsSet = false ;
111
+ this . revertSet = false ;
112
+ reject ( e ) ;
113
+ return ;
114
+ }
115
+ }
116
+
117
+ this . stubCalls = [ ] ;
118
+ this . argsSet = false ;
119
+ this . revertSet = false ;
120
+ resolve ( ) ;
121
+ }
122
+ }
123
+
124
+ type DeployOptions = {
125
+ address : string ;
126
+ override ?: boolean ;
127
+ }
128
+
24
129
async function deploy ( signer : Signer , options ?: DeployOptions ) {
25
130
if ( options ) {
26
131
const { address, override} = options ;
@@ -50,29 +155,12 @@ async function deploy(signer: Signer, options?: DeployOptions) {
50
155
return factory . deploy ( ) ;
51
156
}
52
157
53
- function stub ( mockContract : Contract , encoder : utils . AbiCoder , func : utils . FunctionFragment , params ?: any [ ] ) {
54
- const callData = params
55
- ? mockContract . interface . encodeFunctionData ( func , params )
56
- : mockContract . interface . getSighash ( func ) ;
57
-
58
- return {
59
- returns : async ( ...args : any ) => {
60
- if ( ! func . outputs ) return ;
61
- const encoded = encoder . encode ( func . outputs , args ) ;
62
- await mockContract . __waffle__mockReturns ( callData , encoded ) ;
63
- } ,
64
- reverts : async ( ) => mockContract . __waffle__mockReverts ( callData , 'Mock revert' ) ,
65
- revertsWithReason : async ( reason : string ) => mockContract . __waffle__mockReverts ( callData , reason ) ,
66
- withArgs : ( ...args : any [ ] ) => stub ( mockContract , encoder , func , args )
67
- } ;
68
- }
69
-
70
158
function createMock ( abi : ABI , mockContractInstance : Contract ) {
71
159
const { functions} = new utils . Interface ( abi ) ;
72
160
const encoder = new utils . AbiCoder ( ) ;
73
161
74
162
const mockedAbi = Object . values ( functions ) . reduce ( ( acc , func ) => {
75
- const stubbed = stub ( mockContractInstance , encoder , func ) ;
163
+ const stubbed = new Stub ( mockContractInstance as MockContract , encoder , func ) ;
76
164
return {
77
165
...acc ,
78
166
[ func . name ] : stubbed ,
@@ -81,10 +169,10 @@ function createMock(abi: ABI, mockContractInstance: Contract) {
81
169
} , { } as MockContract [ 'mock' ] ) ;
82
170
83
171
mockedAbi . receive = {
84
- returns : async ( ) => { throw new Error ( 'Receive function return is not implemented.' ) ; } ,
172
+ returns : ( ) => { throw new Error ( 'Receive function return is not implemented.' ) ; } ,
85
173
withArgs : ( ) => { throw new Error ( 'Receive function return is not implemented.' ) ; } ,
86
- reverts : async ( ) => mockContractInstance . __waffle__receiveReverts ( 'Mock Revert' ) ,
87
- revertsWithReason : async ( reason : string ) => mockContractInstance . __waffle__receiveReverts ( reason )
174
+ reverts : ( ) => mockContractInstance . __waffle__receiveReverts ( 'Mock Revert' ) ,
175
+ revertsWithReason : ( reason : string ) => mockContractInstance . __waffle__receiveReverts ( reason )
88
176
} ;
89
177
90
178
return mockedAbi ;
0 commit comments