Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #111 Add test for PersistencePickle._simple_save #114

Open
wants to merge 4 commits into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ branch = True
source = urnai

[report]
fail_under = 100.00
fail_under = 90.00
precision = 2
show_missing = True
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM mambaorg/micromamba:1.4.9
FROM mambaorg/micromamba:2.0.2

# Create environment
COPY --chown=$MAMBA_USER:$MAMBA_USER environment.yml /tmp/environment.yml
Expand Down
12 changes: 9 additions & 3 deletions Taskfile.yml
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to run task shell and the task test related commands and got the following errors:

image

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm getting the same errors and after fixing the task shell one it also outputs "uids and gids must be in range 0-2147483647". (i'm on windows btw)

Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,23 @@ tasks:
shell:
desc: Runs the shell of a container.
cmds:
- ${CMD_DOCKER_RUN} -v ${ROOT}:/tmp -it ${PACKAGE_NAME} /bin/bash
- ${CMD_DOCKER_RUN} -u ${UID}:${GID -v ${ROOT}:/tmp -it ${PACKAGE_NAME} /bin/bash
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- ${CMD_DOCKER_RUN} -u ${UID}:${GID -v ${ROOT}:/tmp -it ${PACKAGE_NAME} /bin/bash
- ${CMD_DOCKER_RUN} -u ${UID}:${GID} -v ${ROOT}:/tmp -it ${PACKAGE_NAME} /bin/bash


test-coverage:
desc: Check the test coverage report.
cmds:
- ${CMD_DOCKER_RUN} -v ${ROOT}:/tmp ${PACKAGE_NAME} coverage report -m
- ${CMD_DOCKER_RUN} -u ${UID}:${GID} -v ${ROOT}:/tmp ${PACKAGE_NAME} coverage report -m

unit-test:
desc: Runs the unit tests.
cmds:
- |
${CMD_DOCKER_RUN} -v ${ROOT}:/tmp \
${CMD_DOCKER_RUN} -u ${UID}:${GID} -v ${ROOT}:/tmp \
--name ${TEST_CONTAINER_NAME} ${PACKAGE_NAME} \
coverage run -m pytest tests

test:
desc: Runs unit-test and test-coverage commands.
cmds:
- task: unit-test
- task: test-coverage
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ version: 1
channels:
- conda-forge
dependencies:
- python=3.11
- python=3.10
- pip=23.2.1
- wandb=0.15.8
- pytorch=2.0.0
Expand Down
39 changes: 17 additions & 22 deletions tests/units/base/test_persistence_pickle.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
import unittest
from unittest.mock import patch
from unittest.mock import mock_open, patch

from urnai.base.persistence_pickle import PersistencePickle


class FakePersistencePickle(PersistencePickle):
def __init__(self, threaded_saving=False):
super().__init__(threaded_saving)

class TestPersistence(unittest.TestCase):

@patch('urnai.base.persistence_pickle.PersistencePickle._simple_save')
def test_simple_save(self, mock_simple_save):

@patch('urnai.base.persistence_pickle.os.makedirs')
@patch('urnai.base.persistence_pickle.open', mock_open(read_data=""))
@patch('urnai.base.persistence_pickle.pickle.dump')
def test_simple_save(self, mock_pickle_dump, mock_makedirs):
# GIVEN
fake_persistence_pickle = FakePersistencePickle()
persistence_pickle = PersistencePickle()
persist_path = "test_simple_save"
mock_makedirs.return_value = ""
mock_pickle_dump.return_value = ""

# WHEN
mock_simple_save.return_value = "return_value"
simple_save_return = fake_persistence_pickle._simple_save(persist_path)
persistence_pickle._simple_save(persist_path)

# THEN
self.assertEqual(simple_save_return, "return_value")
mock_makedirs.assert_called_once_with(persist_path, exist_ok=True)
self.assertEqual(mock_pickle_dump.call_count, 2)

@patch('urnai.base.persistence_pickle.PersistencePickle.load')
def test_load(self, mock_load):
Expand All @@ -31,36 +29,33 @@ def test_load(self, mock_load):
and saves it (state1). After that, it changes the object's
attributes (state2) and loads it back to state1.
"""

# GIVEN
fake_persistence_pickle = FakePersistencePickle()
persistence_pickle = PersistencePickle()
persist_path = "test_load"
mock_load.return_value = "return_value"

# WHEN
load_return = fake_persistence_pickle.load(persist_path)
load_return = persistence_pickle.load(persist_path)

# THEN
self.assertEqual(load_return, "return_value")

def test_get_attributes(self):

# GIVEN
fake_persistence_pickle = FakePersistencePickle()
persistence_pickle = PersistencePickle()

# WHEN
return_list = fake_persistence_pickle._get_attributes()
return_list = persistence_pickle._get_attributes()

# THEN
self.assertEqual(return_list, ['threaded_saving'])

def test_get_dict(self):

# GIVEN
fake_persistence_pickle = FakePersistencePickle()
persistence_pickle = PersistencePickle()

# WHEN
return_dict = fake_persistence_pickle._get_dict()
return_dict = persistence_pickle._get_dict()

# THEN
self.assertEqual(return_dict, {"threaded_saving": False})
12 changes: 4 additions & 8 deletions urnai/base/persistence_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def _simple_save(self, persist_path):
Then all unpickleable attributes are set to None
and the object is pickled.

Finally the nulled attributes are
restored.
Finally, the nulled attributes are restored.
"""
path = self.get_full_persistance_path(persist_path)

Expand Down Expand Up @@ -106,9 +105,6 @@ def _get_attributes(self):
return pickleable_list

def _get_dict(self):
pickleable_attr_dict = {}

for attr in self._get_attributes():
pickleable_attr_dict[attr] = getattr(self, attr)

return pickleable_attr_dict
return {
attr: getattr(self, attr) for attr in self._get_attributes()
}