|
27 | 27 |
|
28 | 28 | _T = TypeVar("_T") |
29 | 29 |
|
30 | | -if sys.version_info[:2] > (3, 7): |
| 30 | +if sys.version_info >= (3, 8): |
31 | 31 | AsyncMockType = unittest.mock.AsyncMock |
| 32 | + MockType = Union[ |
| 33 | + unittest.mock.MagicMock, |
| 34 | + unittest.mock.AsyncMock, |
| 35 | + unittest.mock.NonCallableMagicMock, |
| 36 | + ] |
32 | 37 | else: |
33 | 38 | AsyncMockType = Any |
| 39 | + MockType = Union[unittest.mock.MagicMock, unittest.mock.NonCallableMagicMock] |
34 | 40 |
|
35 | 41 |
|
36 | 42 | class PytestMockWarning(UserWarning): |
@@ -112,7 +118,7 @@ def stop(self, mock: unittest.mock.MagicMock) -> None: |
112 | 118 | else: |
113 | 119 | raise ValueError("This mock object is not registered") |
114 | 120 |
|
115 | | - def spy(self, obj: object, name: str) -> unittest.mock.MagicMock: |
| 121 | + def spy(self, obj: object, name: str) -> MockType: |
116 | 122 | """ |
117 | 123 | Create a spy of method. It will run method normally, but it is now |
118 | 124 | possible to use `mock` call features with it, like call count. |
@@ -205,13 +211,13 @@ def __init__(self, patches_and_mocks, mock_module): |
205 | 211 |
|
206 | 212 | def _start_patch( |
207 | 213 | self, mock_func: Any, warn_on_mock_enter: bool, *args: Any, **kwargs: Any |
208 | | - ) -> unittest.mock.MagicMock: |
| 214 | + ) -> MockType: |
209 | 215 | """Patches something by calling the given function from the mock |
210 | 216 | module, registering the patch to stop it later and returns the |
211 | 217 | mock object resulting from the mock call. |
212 | 218 | """ |
213 | 219 | p = mock_func(*args, **kwargs) |
214 | | - mocked = p.start() # type: unittest.mock.MagicMock |
| 220 | + mocked: MockType = p.start() |
215 | 221 | self.__patches_and_mocks.append((p, mocked)) |
216 | 222 | if hasattr(mocked, "reset_mock"): |
217 | 223 | # check if `mocked` is actually a mock object, as depending on autospec or target |
@@ -242,7 +248,7 @@ def object( |
242 | 248 | autospec: Optional[object] = None, |
243 | 249 | new_callable: object = None, |
244 | 250 | **kwargs: Any |
245 | | - ) -> unittest.mock.MagicMock: |
| 251 | + ) -> MockType: |
246 | 252 | """API to mock.patch.object""" |
247 | 253 | if new is self.DEFAULT: |
248 | 254 | new = self.mock_module.DEFAULT |
@@ -271,7 +277,7 @@ def context_manager( |
271 | 277 | autospec: Optional[builtins.object] = None, |
272 | 278 | new_callable: builtins.object = None, |
273 | 279 | **kwargs: Any |
274 | | - ) -> unittest.mock.MagicMock: |
| 280 | + ) -> MockType: |
275 | 281 | """This is equivalent to mock.patch.object except that the returned mock |
276 | 282 | does not issue a warning when used as a context manager.""" |
277 | 283 | if new is self.DEFAULT: |
@@ -299,7 +305,7 @@ def multiple( |
299 | 305 | autospec: Optional[builtins.object] = None, |
300 | 306 | new_callable: Optional[builtins.object] = None, |
301 | 307 | **kwargs: Any |
302 | | - ) -> Dict[str, unittest.mock.MagicMock]: |
| 308 | + ) -> Dict[str, MockType]: |
303 | 309 | """API to mock.patch.multiple""" |
304 | 310 | return self._start_patch( |
305 | 311 | self.mock_module.patch.multiple, |
@@ -341,7 +347,7 @@ def __call__( |
341 | 347 | autospec: Optional[builtins.object] = ..., |
342 | 348 | new_callable: None = ..., |
343 | 349 | **kwargs: Any |
344 | | - ) -> unittest.mock.MagicMock: |
| 350 | + ) -> MockType: |
345 | 351 | ... |
346 | 352 |
|
347 | 353 | @overload |
|
0 commit comments