From 9fddad0180dcdd39785f1a729cd831bd07378849 Mon Sep 17 00:00:00 2001
From: Deepyaman Datta <deepyaman.datta@utexas.edu>
Date: Wed, 5 Jun 2024 18:42:06 -0600
Subject: [PATCH] Update type hints for wrapper methods, adding Self

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
---
 kedro/io/core.py | 21 +++++++++++++--------
 pyproject.toml   |  1 +
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/kedro/io/core.py b/kedro/io/core.py
index 4d05a14ff6..1ba0a2f2b4 100644
--- a/kedro/io/core.py
+++ b/kedro/io/core.py
@@ -20,6 +20,7 @@
 
 from cachetools import Cache, cachedmethod
 from cachetools.keys import hashkey
+from typing_extensions import Self
 
 from kedro.utils import load_obj
 
@@ -178,9 +179,9 @@ def _logger(self) -> logging.Logger:
         return logging.getLogger(__name__)
 
     @classmethod
-    def _load_wrapper(cls, load_func: Callable[[], _DO]) -> Callable[[], _DO]:
+    def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
         @wraps(load_func)
-        def load(self) -> _DO:
+        def load(self: Self) -> _DO:
             self._logger.debug("Loading %s", str(self))
 
             try:
@@ -200,9 +201,11 @@ def load(self) -> _DO:
         return load
 
     @classmethod
-    def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]:
+    def _save_wrapper(
+        cls, save_func: Callable[[Self, _DI], None]
+    ) -> Callable[[Self, _DI], None]:
         @wraps(save_func)
-        def save(self, data: _DI) -> None:
+        def save(self: Self, data: _DI) -> None:
             if data is None:
                 raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")
 
@@ -226,14 +229,14 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
         super().__init_subclass__(**kwargs)
 
         if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
-            cls.load = cls._load_wrapper(  # type: ignore[method-assign]
+            cls.load = cls._load_wrapper(  # type: ignore[assignment]
                 cls.load
                 if not getattr(cls.load, "__loadwrapped__", False)
                 else cls.load.__wrapped__  # type: ignore[attr-defined]
             )
 
         if hasattr(cls, "save") and not cls.save.__qualname__.startswith("Abstract"):
-            cls.save = cls._save_wrapper(  # type: ignore[method-assign]
+            cls.save = cls._save_wrapper(  # type: ignore[assignment]
                 cls.save
                 if not getattr(cls.save, "__savewrapped__", False)
                 else cls.save.__wrapped__  # type: ignore[attr-defined]
@@ -678,9 +681,11 @@ def load(self) -> _DO:
         return super().load()
 
     @classmethod
-    def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]:
+    def _save_wrapper(
+        cls, save_func: Callable[[Self, _DI], None]
+    ) -> Callable[[Self, _DI], None]:
         @wraps(save_func)
-        def save(self, data: _DI) -> None:
+        def save(self: Self, data: _DI) -> None:
             self._version_cache.clear()
             save_version = (
                 self.resolve_save_version()
diff --git a/pyproject.toml b/pyproject.toml
index 6ece4851cd..e7d80de0ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     "rich>=12.0,<14.0",
     "rope>=0.21,<2.0",  # subject to LGPLv3 license
     "toml>=0.10.0",
+    "typing_extensions>=4.0",
     "graphlib_backport>=1.0.0; python_version < '3.9'",
 ]
 keywords = [