@@ -234,6 +234,8 @@ class _patch(Generic[_T]):
234
234
def copy (self ) -> _patch [_T ]: ...
235
235
@overload
236
236
def __call__ (self , func : _TT ) -> _TT : ...
237
+ # If new==DEFAULT, this should add a MagicMock parameter to the function
238
+ # arguments. See the _patch_default_new class below for this functionality.
237
239
@overload
238
240
def __call__ (self , func : Callable [_P , _R ]) -> Callable [_P , _R ]: ...
239
241
if sys .version_info >= (3 , 8 ):
@@ -257,6 +259,22 @@ class _patch(Generic[_T]):
257
259
def start (self ) -> _T : ...
258
260
def stop (self ) -> None : ...
259
261
262
+ if sys .version_info >= (3 , 8 ):
263
+ _Mock : TypeAlias = MagicMock | AsyncMock
264
+ else :
265
+ _Mock : TypeAlias = MagicMock
266
+
267
+ # This class does not exist at runtime, it's a hack to make this work:
268
+ # @patch("foo")
269
+ # def bar(..., mock: MagicMock) -> None: ...
270
+ class _patch_default_new (_patch [_Mock ]):
271
+ @overload
272
+ def __call__ (self , func : _TT ) -> _TT : ...
273
+ # Can't use the following as ParamSpec is only allowed as last parameter:
274
+ # def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ...
275
+ @overload
276
+ def __call__ (self , func : Callable [..., _R ]) -> Callable [..., _R ]: ...
277
+
260
278
class _patch_dict :
261
279
in_dict : Any
262
280
values : Any
@@ -273,11 +291,8 @@ class _patch_dict:
273
291
start : Any
274
292
stop : Any
275
293
276
- if sys .version_info >= (3 , 8 ):
277
- _Mock : TypeAlias = MagicMock | AsyncMock
278
- else :
279
- _Mock : TypeAlias = MagicMock
280
-
294
+ # This class does not exist at runtime, it's a hack to add methods to the
295
+ # patch() function.
281
296
class _patcher :
282
297
TEST_PREFIX : str
283
298
dict : type [_patch_dict ]
@@ -307,7 +322,7 @@ class _patcher:
307
322
autospec : Any | None = ...,
308
323
new_callable : Any | None = ...,
309
324
** kwargs : Any ,
310
- ) -> _patch [ _Mock ] : ...
325
+ ) -> _patch_default_new : ...
311
326
@overload
312
327
@staticmethod
313
328
def object ( # type: ignore[misc]
0 commit comments