Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Support keys assign in GetterDataset #571

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,40 @@ class GetterDataset(SliceableDataset):
def __init__(self):
self._keys = []
self._getters = []
self._return_tuple = True

def __len__(self):
raise NotImplementedError

@property
def keys(self):
return tuple(key for key, _, _ in self._keys)
if self._return_tuple:
return tuple(key for key, _, _ in self._keys)
else:
return self._keys[0][0]

@keys.setter
def keys(self, keys):
if isinstance(keys, (list, tuple)):
self._return_tuple = True
else:
keys, self._return_tuple = (keys,), False

new_keys = []
for key in keys:
if isinstance(key, int):
key_index = key
if key_index >= len(self._keys):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

if key_index < 0:
    key_index += len(self._keys)
if key_index >= len(self._keys) or key_index < 0:
    raise IndexError('index {} is out of bounds for keys with size {}'.format(key, len(self._keys))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message would be similar to NumPy.

raise IndexError('Invalid index of key')
if key_index < 0:
key_index += len(self._keys)
else:
try:
key_index = _as_tuple(self.keys).index(key)
except ValueError:
raise KeyError('{} does not exists'.format(key))
new_keys.append(self._keys[key_index])
self._keys = new_keys

def add_getter(self, keys, getter):
"""Register a getter function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,48 @@ def test_get_example_by_keys(self):
self.assertEqual(example, ('item1(1)', 'item2(1)', 'item3(1)'))
self.assertEqual(self.dataset.count, 2)

def test_set_keys_single_name(self):
self.dataset.keys = 'item0'
self.assertEqual(self.dataset.keys, 'item0')
self.assertEqual(self.dataset[1], 'item0(1)')

def test_set_keys_single_index(self):
self.dataset.keys = 0
self.assertEqual(self.dataset.keys, 'item0')
self.assertEqual(self.dataset[1], 'item0(1)')

def test_set_keys_single_tuple_name(self):
self.dataset.keys = ('item1',)
self.assertEqual(self.dataset.keys, ('item1',))
self.assertEqual(self.dataset[2], ('item1(2)',))

def test_set_keys_single_tuple_index(self):
self.dataset.keys = (1,)
self.assertEqual(self.dataset.keys, ('item1',))
self.assertEqual(self.dataset[2], ('item1(2)',))

def test_set_keys_multiple_name(self):
self.dataset.keys = ('item0', 'item2')
self.assertEqual(self.dataset.keys, ('item0', 'item2'))
self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_multiple_index(self):
self.dataset.keys = (0, 2)
self.assertEqual(self.dataset.keys, ('item0', 'item2'))
self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_multiple_mixed(self):
self.dataset.keys = ('item0', 2)
self.assertEqual(self.dataset.keys, ('item0', 'item2'))
self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_invalid_name(self):
with self.assertRaises(KeyError):
self.dataset.keys = 'invalid'

def test_set_keys_invalid_index(self):
with self.assertRaises(IndexError):
self.dataset.keys = 4


testing.run_module(__name__, __file__)