Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Support assigning keys in GetterDataset #571

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset \
import _as_indices
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


def _as_tuple(t):
if isinstance(t, tuple):
return t
else:
return t,


class GetterDataset(SliceableDataset):
"""A sliceable dataset class that is defined with getters.

Expand Down Expand Up @@ -49,13 +44,23 @@ class GetterDataset(SliceableDataset):
def __init__(self):
self._keys = []
self._getters = []
self._return_tuple = True

def __len__(self):
raise NotImplementedError

@property
def keys(self):
return tuple(key for key, _, _ in self._keys)
if self._return_tuple:
return tuple(key for key, _, _ in self._keys)
else:
return self._keys[0][0]

@keys.setter
def keys(self, keys):
    """Select which of the current keys the dataset exposes.

    Args:
        keys: Key name(s) and/or index(es). They are resolved against the
            current value of ``self.keys`` by ``_as_indices``, so any
            error that helper raises propagates and leaves the dataset
            unchanged.

    After a successful assignment ``self.keys`` reports a tuple when
    ``keys`` was a list or tuple, and a single name otherwise.
    """
    selected = []
    for position in _as_indices(keys, self.keys):
        selected.append(self._keys[position])
    # Commit only after every key resolved successfully.
    self._keys = selected
    self._return_tuple = isinstance(keys, (list, tuple))

def add_getter(self, keys, getter):
"""Register a getter function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@ def _as_tuple(t):
return t,


def _as_indices(keys, key_names):
keys = _as_tuple(keys)
key_names = _as_tuple(key_names)

for key in keys:
if isinstance(key, int):
key_index = key
if key_index < 0:
key_index += len(key_names)
if key_index not in range(0, len(key_names)):
raise IndexError(
'index {} is out of bounds for keys with size {}'.format(
key, len(key_names)))
else:
try:
key_index = key_names.index(key)
except ValueError:
raise KeyError('{} does not exists'.format(key))
yield key_index


class SliceableDataset(chainer.dataset.DatasetMixin):
"""An abstract dataset class that supports slicing.

Expand Down Expand Up @@ -76,26 +97,8 @@ def __getitem__(self, args):
index = args
keys = self._dataset.keys

if isinstance(keys, (list, tuple)):
return_tuple = True
else:
keys, return_tuple = (keys,), False

# convert name to index
key_indices = []
for key in keys:
if isinstance(key, int):
key_index = key
if key_index >= len(self._dataset.keys):
raise IndexError('Invalid index of key')
if key_index < 0:
key_index += len(self._dataset.keys)
else:
try:
key_index = _as_tuple(self._dataset.keys).index(key)
except ValueError:
raise KeyError('{} does not exists'.format(key))
key_indices.append(key_index)
key_indices = tuple(_as_indices(keys, self._dataset.keys))
return_tuple = isinstance(keys, (list, tuple))

return SlicedDataset(
self._dataset, index,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from collections import defaultdict
import six

from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset \
import _as_tuple
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


def _as_tuple(t):
if isinstance(t, tuple):
return t
else:
return t,


class TupleDataset(SliceableDataset):
"""A sliceable version of :class:`chainer.datasets.TupleDataset`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,48 @@ def test_get_example_by_keys(self):
self.assertEqual(example, ('item1(1)', 'item2(1)', 'item3(1)'))
self.assertEqual(self.dataset.count, 2)

def test_set_keys_single_name(self):
    """Assigning a bare key name yields a scalar ``keys`` and example."""
    selected = 'item0'
    self.dataset.keys = selected
    self.assertEqual(self.dataset.keys, selected)
    self.assertEqual(self.dataset[1], 'item0(1)')

def test_set_keys_single_index(self):
    """Assigning a bare index yields the matching name and a scalar example."""
    expected_name = 'item0'
    self.dataset.keys = 0
    self.assertEqual(self.dataset.keys, expected_name)
    self.assertEqual(self.dataset[1], 'item0(1)')

def test_set_keys_single_tuple_name(self):
    """A 1-tuple of names keeps ``keys`` and the example as tuples."""
    selected = ('item1',)
    self.dataset.keys = selected
    self.assertEqual(self.dataset.keys, selected)
    self.assertEqual(self.dataset[2], ('item1(2)',))

def test_set_keys_single_tuple_index(self):
    """A 1-tuple of indices keeps ``keys`` and the example as tuples."""
    expected_names = ('item1',)
    self.dataset.keys = (1,)
    self.assertEqual(self.dataset.keys, expected_names)
    self.assertEqual(self.dataset[2], ('item1(2)',))

def test_set_keys_multiple_name(self):
    """Several names select those keys, in the order given."""
    selected = ('item0', 'item2')
    self.dataset.keys = selected
    self.assertEqual(self.dataset.keys, selected)
    self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_multiple_index(self):
    """Several indices select the matching keys, reported by name."""
    expected_names = ('item0', 'item2')
    self.dataset.keys = (0, 2)
    self.assertEqual(self.dataset.keys, expected_names)
    self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_multiple_mixed(self):
    """Names and indices may be mixed within one assignment."""
    expected_names = ('item0', 'item2')
    self.dataset.keys = ('item0', 2)
    self.assertEqual(self.dataset.keys, expected_names)
    self.assertEqual(self.dataset[3], ('item0(3)', 'item2(3)'))

def test_set_keys_invalid_name(self):
    """Assigning the name ``'invalid'`` must raise ``KeyError``."""
    def assign():
        self.dataset.keys = 'invalid'

    self.assertRaises(KeyError, assign)

def test_set_keys_invalid_index(self):
    """Assigning index ``4`` must raise ``IndexError``.

    NOTE(review): presumably 4 is one past the fixture's last key --
    confirm against the test's setUp.
    """
    def assign():
        self.dataset.keys = 4

    self.assertRaises(IndexError, assign)


# Module-level entry point: hands this module's name and file path to the
# project's `testing` helper. NOTE(review): presumably this runs the tests
# above when the file is executed directly -- confirm against
# `testing.run_module`.
testing.run_module(__name__, __file__)