Skip to content

Commit 9fddad0

Browse files
committed
Update type hints for wrapper methods, adding Self
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
1 parent 25e5860 commit 9fddad0

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

kedro/io/core.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from cachetools import Cache, cachedmethod
2222
from cachetools.keys import hashkey
23+
from typing_extensions import Self
2324

2425
from kedro.utils import load_obj
2526

@@ -178,9 +179,9 @@ def _logger(self) -> logging.Logger:
178179
return logging.getLogger(__name__)
179180

180181
@classmethod
181-
def _load_wrapper(cls, load_func: Callable[[], _DO]) -> Callable[[], _DO]:
182+
def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
182183
@wraps(load_func)
183-
def load(self) -> _DO:
184+
def load(self: Self) -> _DO:
184185
self._logger.debug("Loading %s", str(self))
185186

186187
try:
@@ -200,9 +201,11 @@ def load(self) -> _DO:
200201
return load
201202

202203
@classmethod
203-
def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]:
204+
def _save_wrapper(
205+
cls, save_func: Callable[[Self, _DI], None]
206+
) -> Callable[[Self, _DI], None]:
204207
@wraps(save_func)
205-
def save(self, data: _DI) -> None:
208+
def save(self: Self, data: _DI) -> None:
206209
if data is None:
207210
raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")
208211

@@ -226,14 +229,14 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
226229
super().__init_subclass__(**kwargs)
227230

228231
if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
229-
cls.load = cls._load_wrapper( # type: ignore[method-assign]
232+
cls.load = cls._load_wrapper( # type: ignore[assignment]
230233
cls.load
231234
if not getattr(cls.load, "__loadwrapped__", False)
232235
else cls.load.__wrapped__ # type: ignore[attr-defined]
233236
)
234237

235238
if hasattr(cls, "save") and not cls.save.__qualname__.startswith("Abstract"):
236-
cls.save = cls._save_wrapper( # type: ignore[method-assign]
239+
cls.save = cls._save_wrapper( # type: ignore[assignment]
237240
cls.save
238241
if not getattr(cls.save, "__savewrapped__", False)
239242
else cls.save.__wrapped__ # type: ignore[attr-defined]
@@ -678,9 +681,11 @@ def load(self) -> _DO:
678681
return super().load()
679682

680683
@classmethod
681-
def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]:
684+
def _save_wrapper(
685+
cls, save_func: Callable[[Self, _DI], None]
686+
) -> Callable[[Self, _DI], None]:
682687
@wraps(save_func)
683-
def save(self, data: _DI) -> None:
688+
def save(self: Self, data: _DI) -> None:
684689
self._version_cache.clear()
685690
save_version = (
686691
self.resolve_save_version()

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"rich>=12.0,<14.0",
3232
"rope>=0.21,<2.0", # subject to LGPLv3 license
3333
"toml>=0.10.0",
34+
"typing_extensions>=4.0",
3435
"graphlib_backport>=1.0.0; python_version < '3.9'",
3536
]
3637
keywords = [

0 commit comments

Comments
 (0)