Skip to content

Commit ba8255a

Browse files
committed
Auto-convert from _load/_save to load/save
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
1 parent c25bab3 commit ba8255a

6 files changed

+52
-111
lines changed

kedro/io/cached_dataset.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,18 @@ def _describe(self) -> dict[str, Any]:
9898
"cache": self._cache._describe(),
9999
}
100100

101-
def _load(self) -> Any:
101+
def load(self) -> Any:
102102
data = self._cache.load() if self._cache.exists() else self._dataset.load()
103103

104104
if not self._cache.exists():
105105
self._cache.save(data)
106106

107107
return data
108108

109-
load = _load
110-
111-
def _save(self, data: Any) -> None:
109+
def save(self, data: Any) -> None:
112110
self._dataset.save(data)
113111
self._cache.save(data)
114112

115-
save = _save
116-
117113
def _exists(self) -> bool:
118114
return self._cache.exists() or self._dataset.exists()
119115

kedro/io/core.py

+43-66
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,33 @@ def from_config(
178178
def _logger(self) -> logging.Logger:
179179
return logging.getLogger(__name__)
180180

181+
def __str__(self) -> str:
182+
def _to_str(obj: Any, is_root: bool = False) -> str:
183+
"""Returns a string representation where
184+
1. The root level (i.e. the Dataset.__init__ arguments) are
185+
formatted like Dataset(key=value).
186+
2. Dictionaries have the keys alphabetically sorted recursively.
187+
3. None values are not shown.
188+
"""
189+
190+
fmt = "{}={}" if is_root else "'{}': {}" # 1
191+
192+
if isinstance(obj, dict):
193+
sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0])) # 2
194+
195+
text = ", ".join(
196+
fmt.format(key, _to_str(value)) # 2
197+
for key, value in sorted_dict
198+
if value is not None # 3
199+
)
200+
201+
return text if is_root else "{" + text + "}" # 1
202+
203+
# not a dictionary
204+
return str(obj)
205+
206+
return f"{type(self).__name__}({_to_str(self._describe(), True)})"
207+
181208
@classmethod
182209
def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
183210
@wraps(load_func)
@@ -228,6 +255,12 @@ def save(self: Self, data: _DI) -> None:
228255
def __init_subclass__(cls, **kwargs: Any) -> None:
229256
super().__init_subclass__(**kwargs)
230257

258+
if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
259+
cls.load = cls._load # type: ignore[method-assign]
260+
261+
if hasattr(cls, "_save") and not cls._save.__qualname__.startswith("Abstract"):
262+
cls.save = cls._save # type: ignore[method-assign]
263+
231264
if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
232265
cls.load = cls._load_wrapper( # type: ignore[assignment]
233266
cls.load
@@ -242,6 +275,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
242275
else cls.save.__wrapped__ # type: ignore[attr-defined]
243276
)
244277

278+
@abc.abstractmethod
245279
def load(self) -> _DO:
246280
"""Loads data by delegation to the provided load method.
247281
@@ -252,21 +286,12 @@ def load(self) -> _DO:
252286
DatasetError: When underlying load method raises error.
253287
254288
"""
289+
raise NotImplementedError(
290+
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
291+
f"it must implement the 'load' method"
292+
)
255293

256-
self._logger.debug("Loading %s", str(self))
257-
258-
try:
259-
return self._load()
260-
except DatasetError:
261-
raise
262-
except Exception as exc:
263-
# This exception handling is by design as the composed data sets
264-
# can throw any type of exception.
265-
message = (
266-
f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
267-
)
268-
raise DatasetError(message) from exc
269-
294+
@abc.abstractmethod
270295
def save(self, data: _DI) -> None:
271296
"""Saves data by delegation to the provided save method.
272297
@@ -277,59 +302,11 @@ def save(self, data: _DI) -> None:
277302
DatasetError: when underlying save method raises error.
278303
FileNotFoundError: when save method got file instead of dir, on Windows.
279304
NotADirectoryError: when save method got file instead of dir, on Unix.
280-
"""
281-
282-
if data is None:
283-
raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")
284-
285-
try:
286-
self._logger.debug("Saving %s", str(self))
287-
self._save(data)
288-
except (DatasetError, FileNotFoundError, NotADirectoryError):
289-
raise
290-
except Exception as exc:
291-
message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
292-
raise DatasetError(message) from exc
293-
294-
def __str__(self) -> str:
295-
def _to_str(obj: Any, is_root: bool = False) -> str:
296-
"""Returns a string representation where
297-
1. The root level (i.e. the Dataset.__init__ arguments) are
298-
formatted like Dataset(key=value).
299-
2. Dictionaries have the keys alphabetically sorted recursively.
300-
3. None values are not shown.
301-
"""
302-
303-
fmt = "{}={}" if is_root else "'{}': {}" # 1
304-
305-
if isinstance(obj, dict):
306-
sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0])) # 2
307305
308-
text = ", ".join(
309-
fmt.format(key, _to_str(value)) # 2
310-
for key, value in sorted_dict
311-
if value is not None # 3
312-
)
313-
314-
return text if is_root else "{" + text + "}" # 1
315-
316-
# not a dictionary
317-
return str(obj)
318-
319-
return f"{type(self).__name__}({_to_str(self._describe(), True)})"
320-
321-
@abc.abstractmethod
322-
def _load(self) -> _DO:
323-
raise NotImplementedError(
324-
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
325-
f"it must implement the '_load' method"
326-
)
327-
328-
@abc.abstractmethod
329-
def _save(self, data: _DI) -> None:
306+
"""
330307
raise NotImplementedError(
331308
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
332-
f"it must implement the '_save' method"
309+
f"it must implement the 'save' method"
333310
)
334311

335312
@abc.abstractmethod
@@ -682,7 +659,7 @@ def _get_versioned_path(self, version: str) -> PurePosixPath:
682659
return self._filepath / version / self._filepath.name
683660

684661
def load(self) -> _DO:
685-
return super().load()
662+
return super().load() # type: ignore[safe-super]
686663

687664
@classmethod
688665
def _save_wrapper(
@@ -724,7 +701,7 @@ def save(self, data: _DI) -> None:
724701
self._version_cache.clear()
725702
save_version = self.resolve_save_version() # Make sure last save version is set
726703
try:
727-
super().save(data)
704+
super().save(data) # type: ignore[safe-super]
728705
except (FileNotFoundError, NotADirectoryError) as err:
729706
# FileNotFoundError raised in Win, NotADirectoryError raised in Unix
730707
_default_version = "YYYY-MM-DDThh.mm.ss.sssZ"

kedro/io/lambda_dataset.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -49,26 +49,22 @@ def _to_str(func: Any) -> str | None:
4949

5050
return descr
5151

52-
def _load(self) -> Any:
52+
def load(self) -> Any:
5353
if not self.__load:
5454
raise DatasetError(
5555
"Cannot load data set. No 'load' function "
5656
"provided when LambdaDataset was created."
5757
)
5858
return self.__load()
5959

60-
load = _load
61-
62-
def _save(self, data: Any) -> None:
60+
def save(self, data: Any) -> None:
6361
if not self.__save:
6462
raise DatasetError(
6563
"Cannot save to data set. No 'save' function "
6664
"provided when LambdaDataset was created."
6765
)
6866
self.__save(data)
6967

70-
save = _save
71-
7268
def _exists(self) -> bool:
7369
if not self.__exists:
7470
return super()._exists()

kedro/io/memory_dataset.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,20 @@ def __init__(
5757
self.metadata = metadata
5858
self._EPHEMERAL = True
5959
if data is not _EMPTY:
60-
self._save(data)
60+
self.save.__wrapped__(self, data) # type: ignore[attr-defined]
6161

62-
def _load(self) -> Any:
62+
def load(self) -> Any:
6363
if self._data is _EMPTY:
6464
raise DatasetError("Data for MemoryDataset has not been saved yet.")
6565

6666
copy_mode = self._copy_mode or _infer_copy_mode(self._data)
6767
data = _copy_with_mode(self._data, copy_mode=copy_mode)
6868
return data
6969

70-
load = _load
71-
72-
def _save(self, data: Any) -> None:
70+
def save(self, data: Any) -> None:
7371
copy_mode = self._copy_mode or _infer_copy_mode(data)
7472
self._data = _copy_with_mode(data, copy_mode=copy_mode)
7573

76-
save = _save
77-
7874
def _exists(self) -> bool:
7975
return self._data is not _EMPTY
8076

kedro/io/shared_memory_dataset.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ def __getattr__(self, name: str) -> Any:
3434
raise AttributeError()
3535
return getattr(self.shared_memory_dataset, name) # pragma: no cover
3636

37-
def _load(self) -> Any:
37+
def load(self) -> Any:
3838
return self.shared_memory_dataset.load()
3939

40-
load = _load
41-
42-
def _save(self, data: Any) -> None:
40+
def save(self, data: Any) -> None:
4341
"""Calls save method of a shared MemoryDataset in SyncManager."""
4442
try:
4543
self.shared_memory_dataset.save(data)
@@ -54,8 +52,6 @@ def _save(self, data: Any) -> None:
5452
) from serialisation_exc
5553
raise exc # pragma: no cover
5654

57-
save = _save
58-
5955
def _describe(self) -> dict[str, Any]:
6056
"""SharedMemoryDataset doesn't have any constructor argument to return."""
6157
return {}

tests/io/test_core.py

-20
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,10 @@ def _exists(self) -> bool:
5858
def _load(self):
5959
return pd.read_csv(self._filepath)
6060

61-
load = _load
62-
6361
def _save(self, data: str) -> None:
6462
with open(self._filepath, mode="w") as file:
6563
file.write(data)
6664

67-
save = _save
68-
6965

7066
class MyVersionedDataset(AbstractVersionedDataset[str, str]):
7167
def __init__( # noqa: PLR0913
@@ -96,16 +92,12 @@ def _load(self) -> str:
9692
with self._fs.open(load_path, mode="r") as fs_file:
9793
return fs_file.read()
9894

99-
load = _load
100-
10195
def _save(self, data: str) -> None:
10296
save_path = get_filepath_str(self._get_save_path(), self._protocol)
10397

10498
with self._fs.open(save_path, mode="w") as fs_file:
10599
fs_file.write(data)
106100

107-
save = _save
108-
109101
def _exists(self) -> bool:
110102
try:
111103
load_path = get_filepath_str(self._get_load_path(), self._protocol)
@@ -143,16 +135,12 @@ def _load(self) -> str:
143135
with self._fs.open(load_path, mode="r") as fs_file:
144136
return fs_file.read()
145137

146-
load = _load
147-
148138
def _save(self, data: str) -> None:
149139
save_path = get_filepath_str(self._get_save_path(), self._protocol)
150140

151141
with self._fs.open(save_path, mode="w") as fs_file:
152142
fs_file.write(data)
153143

154-
save = _save
155-
156144
def _exists(self) -> bool:
157145
load_path = get_filepath_str(self._get_load_path(), self._protocol)
158146
# no try catch - will return a VersionNotFoundError to be caught be AbstractVersionedDataset.exists()
@@ -458,14 +446,10 @@ def _exists(self) -> bool:
458446
def _load(self):
459447
return pd.read_csv(self._filepath)
460448

461-
# load = _load
462-
463449
def _save(self, data: str) -> None:
464450
with open(self._filepath, mode="w") as file:
465451
file.write(data)
466452

467-
# save = _save
468-
469453

470454
class MyLegacyVersionedDataset(AbstractVersionedDataset[str, str]):
471455
def __init__( # noqa: PLR0913
@@ -496,16 +480,12 @@ def _load(self) -> str:
496480
with self._fs.open(load_path, mode="r") as fs_file:
497481
return fs_file.read()
498482

499-
# load = _load
500-
501483
def _save(self, data: str) -> None:
502484
save_path = get_filepath_str(self._get_save_path(), self._protocol)
503485

504486
with self._fs.open(save_path, mode="w") as fs_file:
505487
fs_file.write(data)
506488

507-
# save = _save
508-
509489
def _exists(self) -> bool:
510490
try:
511491
load_path = get_filepath_str(self._get_load_path(), self._protocol)

0 commit comments

Comments
 (0)