diff --git a/chainercv/extensions/evaluator/detection_voc_evaluator.py b/chainercv/extensions/evaluator/detection_voc_evaluator.py index 945774e2a0..1bd4b38df7 100644 --- a/chainercv/extensions/evaluator/detection_voc_evaluator.py +++ b/chainercv/extensions/evaluator/detection_voc_evaluator.py @@ -5,7 +5,7 @@ import chainer.training.extensions from chainercv.evaluations import eval_detection_voc -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator class DetectionVOCEvaluator(chainer.training.extensions.Evaluator): @@ -72,17 +72,17 @@ def evaluate(self): else: it = copy.copy(iterator) - imgs, pred_values, gt_values = apply_prediction_to_iterator( + in_values, out_values, rest_values = apply_to_iterator( target.predict, it) - # delete unused iterator explicitly - del imgs + # delete unused iterators explicitly + del in_values - pred_bboxes, pred_labels, pred_scores = pred_values + pred_bboxes, pred_labels, pred_scores = out_values - if len(gt_values) == 3: - gt_bboxes, gt_labels, gt_difficults = gt_values - elif len(gt_values) == 2: - gt_bboxes, gt_labels = gt_values + if len(rest_values) == 3: + gt_bboxes, gt_labels, gt_difficults = rest_values + elif len(rest_values) == 2: + gt_bboxes, gt_labels = rest_values gt_difficults = None result = eval_detection_voc( diff --git a/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py b/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py index 5b39dfa26d..9ac53552e5 100644 --- a/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py +++ b/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py @@ -5,7 +5,7 @@ import chainer.training.extensions from chainercv.evaluations import eval_semantic_segmentation -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator class SemanticSegmentationEvaluator(chainer.training.extensions.Evaluator): @@ -79,13 +79,13 @@ def evaluate(self): else: it = copy.copy(iterator) - imgs, pred_values, gt_values = apply_prediction_to_iterator( + in_values, out_values, rest_values = apply_to_iterator( target.predict, it) - # delete unused iterator explicitly - del imgs + # delete unused iterators explicitly + del in_values - pred_labels, = pred_values - gt_labels, = gt_values + pred_labels, = out_values + gt_labels, = rest_values result = eval_semantic_segmentation(pred_labels, gt_labels) diff --git a/chainercv/utils/__init__.py b/chainercv/utils/__init__.py index 47f48e3e77..9ee7dd02ab 100644 --- a/chainercv/utils/__init__.py +++ b/chainercv/utils/__init__.py @@ -6,7 +6,7 @@ from chainercv.utils.image import read_image # NOQA from chainercv.utils.image import tile_images # NOQA from chainercv.utils.image import write_image # NOQA -from chainercv.utils.iterator import apply_prediction_to_iterator # NOQA +from chainercv.utils.iterator import apply_to_iterator # NOQA from chainercv.utils.iterator import ProgressHook # NOQA from chainercv.utils.iterator import unzip # NOQA from chainercv.utils.testing import assert_is_bbox # NOQA diff --git a/chainercv/utils/iterator/__init__.py b/chainercv/utils/iterator/__init__.py index f7c736a059..fb657c7e86 100644 --- a/chainercv/utils/iterator/__init__.py +++ b/chainercv/utils/iterator/__init__.py @@ -1,3 +1,3 @@ -from chainercv.utils.iterator.apply_prediction_to_iterator import apply_prediction_to_iterator # NOQA +from chainercv.utils.iterator.apply_to_iterator import apply_to_iterator # NOQA from chainercv.utils.iterator.progress_hook import ProgressHook # NOQA from chainercv.utils.iterator.unzip import unzip # NOQA diff --git a/chainercv/utils/iterator/apply_prediction_to_iterator.py b/chainercv/utils/iterator/apply_prediction_to_iterator.py deleted file mode 100644 index e15e5cd624..0000000000 --- a/chainercv/utils/iterator/apply_prediction_to_iterator.py +++ /dev/null @@ -1,141 +0,0 @@ -from chainercv.utils.iterator.unzip import unzip - - -def apply_prediction_to_iterator(predict, iterator, hook=None): - """Apply a prediction function/method to an iterator. - - This function applies a prediction function/method to an iterator. - It assumes that the iterator returns a batch of images or - a batch of tuples whose first element is an image. In the case that - it returns a batch of tuples, the rests are treated as ground truth - values. - - >>> imgs = next(iterator) - >>> # imgs: [img] - or - >>> batch = next(iterator) - >>> # batch: [(img, gt_val0, gt_val1)] - - This function applys :func:`predict` to a batch of images and gets - predicted value(s). :func:`predict` should take a batch of images and - return a batch of prediction values - or a tuple of batches of prediction values. - - >>> pred_vals0 = predict(imgs) - >>> # pred_vals0: [pred_val0] - or - >>> pred_vals0, pred_vals1 = predict(imgs) - >>> # pred_vals0: [pred_val0] - >>> # pred_vals1: [pred_val1] - - Here is an exmple, which applies a pretrained Faster R-CNN to - PASCAL VOC dataset. - - >>> from chainer import iterators - >>> - >>> from chainercv.datasets import VOCDetectionDataset - >>> from chainercv.links import FasterRCNNVGG16 - >>> from chainercv.utils import apply_prediction_to_iterator - >>> - >>> dataset = VOCDetectionDataset(year='2007', split='test') - >>> # next(iterator) -> [(img, gt_bbox, gt_label)] - >>> iterator = iterators.SerialIterator( - ... dataset, 2, repeat=False, shuffle=False) - >>> - >>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score]) - >>> model = FasterRCNNVGG16(pretrained_model='voc07') - >>> - >>> imgs, pred_values, gt_values = apply_prediction_to_iterator( - ... model.predict, iterator) - >>> - >>> # pred_values contains three iterators - >>> pred_bboxes, pred_labels, pred_scores = pred_values - >>> # gt_values contains two iterators - >>> gt_bboxes, gt_labels = gt_values - - Args: - predict: A callable that takes a batch of images and returns - prediction. - iterator (chainer.Iterator): An iterator. Each sample should have - an image as its first element. This image is passed to - :func:`predict` as an argument. - The rests are treated as ground truth values. - hook: A callable that is called after each iteration. - :obj:`imgs`, :obj:`pred_values` and :obj:`gt_values` are passed as - arguments. - Note that these values do not contain data from the previous - iterations. - - Returns: - An iterator and two tuples of iterators: - This function returns an iterator and two tuples of iterators: - :obj:`imgs`, :obj:`pred_values` and :obj:`gt_values`. - - * :obj:`imgs`: An iterator that returns an image. - * :obj:`pred_values`: A tuple of iterators. Each iterator \ - returns a corresponding predicted value. \ - For example, if :func:`predict` returns \ - :obj:`([pred_val0], [pred_val1])`, :obj:`next(pred_values[0])` \ - and :obj:`next(pred_values[1])` will be \ - :obj:`pred_val0` and :obj:`pred_val1`. - * :obj:`gt_values`: A tuple of iterators. Each iterator \ - returns a corresponding ground truth value. \ - For example, if the :obj:`iterator` returns \ - :obj:`[(img, gt_val0, gt_val1)]`, :obj:`next(gt_values[0])` \ - and :obj:`next(gt_values[1])` will be \ - :obj:`gt_val0` and :obj:`gt_val1`. \ - If the input \ - iterator does not give any ground truth values, this tuple \ - will be empty. - """ - - imgs, pred_values, gt_values = unzip( - _apply(predict, iterator, hook)) - - # imgs: iter of [img] -> iter of img - imgs = _flatten(imgs) - - # pred_values: iter of ([pred_val0], [pred_val1], ...) - # -> (iter of pred_val0, iter of pred_val1, ...) - pred_values = tuple(map(_flatten, unzip(pred_values))) - - # gt_values: iter of ([gt_val0], [gt_val1], ...) - # -> (iter of gt_val0, iter of gt_val1, ...) - gt_values = tuple(map(_flatten, unzip(gt_values))) - - return imgs, pred_values, gt_values - - -def _apply(predict, iterator, hook): - for batch in iterator: - # batch: [(img, gt_val0, gt_val1, ...)] or [img] - - imgs = [] - gt_values = [] - for sample in batch: - if isinstance(sample, tuple): - imgs.append(sample[0]) - gt_values.append(sample[1:]) - else: - imgs.append(sample) - gt_values.append(tuple()) - - # imgs: [img] - - # gt_values: [(gt_val0, gt_val1, ...)] -> ([gt_val0], [gt_val1], ...) - gt_values = tuple(list(v) for v in zip(*gt_values)) - - # pred_values: ([pred_val0], [pred_val1], ...) or [pred_val] - pred_values = predict(imgs) - if not isinstance(pred_values, tuple): - # pred_values: [pred_val] -> ([pred_val],) - pred_values = pred_values, - - if hook: - hook(imgs, pred_values, gt_values) - - yield imgs, pred_values, gt_values - - -def _flatten(iterator): - return (sample for batch in iterator for sample in batch) diff --git a/chainercv/utils/iterator/apply_to_iterator.py b/chainercv/utils/iterator/apply_to_iterator.py new file mode 100644 index 0000000000..cd0073010a --- /dev/null +++ b/chainercv/utils/iterator/apply_to_iterator.py @@ -0,0 +1,169 @@ +from chainercv.utils.iterator.unzip import unzip + + +def apply_to_iterator(func, iterator, n_input=1, hook=None): + """Apply a function/method to batches from an iterator. + + This function applies a function/method to an iterator of batches. + + It assumes that the iterator iterates over a collection of tuples + that contain inputs to :func:`func`. + Additionally, the tuples may contain values + that are not used by :func:`func`. + For convenience, we allow the iterator to iterate over a collection of + inputs that are not tuple. + Here is an illustration of the expected behavior of the iterator. + This behaviour is the same as :class:`chainer.Iterator`. + + >>> batch = next(iterator) + >>> # batch: [in_val] + or + >>> # batch: [(in_val0, ..., in_val{n_input - 1})] + or + >>> # batch: [(in_val0, ..., in_val{n_input - 1}, rest_val0, ...)] + + :func:`func` should take batch(es) of data and + return batch(es) of computed values. + Here is an illustration of the expected behavior of the function. + + >>> out_vals = func([in_val0], ..., [in_val{n_input - 1}]) + >>> # out_vals: [out_val] + or + >>> out_vals0, out_vals1, ... = func([in_val0], ..., [in_val{n_input - 1}]) + >>> # out_vals0: [out_val0] + >>> # out_vals1: [out_val1] + + With :func:`apply_to_iterator`, users can get iterator(s) of values + returned by :func:`func`. It also returns iterator(s) of input values and + values that are not used for computation. + + >>> in_values, out_values, rest_values = apply_to_iterator( + >>> func, iterator, n_input) + >>> # in_values: (iter of in_val0, ..., iter of in_val{n_input - 1}) + >>> # out_values: (iter of out_val0, ...) + >>> # rest_values: (iter of rest_val0, ...) + + Here is an exmple, which applies a pretrained Faster R-CNN to + PASCAL VOC dataset. + + >>> from chainer import iterators + >>> + >>> from chainercv.datasets import VOCBBoxDataset + >>> from chainercv.links import FasterRCNNVGG16 + >>> from chainercv.utils import apply_to_iterator + >>> + >>> dataset = VOCBBoxDataset(year='2007', split='test') + >>> # next(iterator) -> [(img, gt_bbox, gt_label)] + >>> iterator = iterators.SerialIterator( + ... dataset, 2, repeat=False, shuffle=False) + >>> + >>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score]) + >>> model = FasterRCNNVGG16(pretrained_model='voc07') + >>> + >>> in_values, out_values, rest_values = apply_to_iterator( + ... model.predict, iterator) + >>> + >>> # in_values contains one iterator + >>> imgs, = in_values + >>> # out_values contains three iterators + >>> pred_bboxes, pred_labels, pred_scores = out_values + >>> # rest_values contains two iterators + >>> gt_bboxes, gt_labels = rest_values + + Args: + func: A callable that takes batch(es) of input data and returns + computed data. + iterator (iterator): An iterator of batches. + The first :obj:`n_input` elements in each sample are + treated as input values. They are passed to :obj:`func`. + n_input (int): The number of input data. The default value is :obj:`1`. + hook: A callable that is called after each iteration. + :obj:`in_values`, :obj:`out_values`, and :obj:`rest_values` + are passed as arguments. + Note that these values do not contain data from the previous + iterations. + + Returns: + Three tuples of iterators: + This function returns three tuples of iterators: + :obj:`in_values`, :obj:`out_values` and :obj:`rest_values`. + + * :obj:`in_values`: A tuple of iterators. Each iterator \ + returns a corresponding input value. \ + For example, if :func:`func` takes \ + :obj:`[in_val0], [in_val1]`, :obj:`next(in_values[0])` \ + and :obj:`next(in_values[1])` will be \ + :obj:`in_val0` and :obj:`in_val1`. + * :obj:`out_values`: A tuple of iterators. Each iterator \ + returns a corresponding computed value. \ + For example, if :func:`func` returns \ + :obj:`([out_val0], [out_val1])`, :obj:`next(out_values[0])` \ + and :obj:`next(out_values[1])` will be \ + :obj:`out_val0` and :obj:`out_val1`. + * :obj:`rest_values`: A tuple of iterators. Each iterator \ + returns a corresponding rest value. \ + For example, if the :obj:`iterator` returns \ + :obj:`[(in_val0, in_val1, rest_val0, rest_val1)]`, \ + :obj:`next(rest_values[0])` \ + and :obj:`next(rest_values[1])` will be \ + :obj:`rest_val0` and :obj:`rest_val1`. \ + If the input \ + iterator does not give any rest values, this tuple \ + will be empty. + """ + + in_values, out_values, rest_values = unzip( + _apply(func, iterator, n_input, hook)) + + # in_values: iter of ([in_val0], [in_val1], ...) + # -> (iter of in_val0, iter of in_val1, ...) + in_values = tuple(map(_flatten, unzip(in_values))) + + # out_values: iter of ([out_val0], [out_val1], ...) + # -> (iter of out_val0, iter of out_val1, ...) + out_values = tuple(map(_flatten, unzip(out_values))) + + # rest_values: iter of ([rest_val0], [rest_val1], ...) + # -> (iter of rest_val0, iter of rest_val1, ...) + rest_values = tuple(map(_flatten, unzip(rest_values))) + + return in_values, out_values, rest_values + + +def _apply(func, iterator, n_input, hook): + for batch in iterator: + # batch: [(in_val0, in_val1, ... , rest_val0, rest_val1, ...)] or + # [in_val] + + in_values = [] + rest_values = [] + for sample in batch: + if isinstance(sample, tuple): + in_values.append(sample[0:n_input]) + rest_values.append(sample[n_input:]) + else: + in_values.append((sample,)) + rest_values.append(()) + + # in_values: [(in_val0, in_val1, ...)] + # -> ([in_val0], [in_val1], ...) + in_values = tuple(list(v) for v in zip(*in_values)) + + # rest_values: [(rest_val0, rest_val1, ...)] + # -> ([rest_val0], [rest_val1], ...) + rest_values = tuple(list(v) for v in zip(*rest_values)) + + # out_values: ([out_val0], [out_val1], ...) or [out_val] + out_values = func(*in_values) + if not isinstance(out_values, tuple): + # pred_values: [out_val] -> ([out_val],) + out_values = out_values, + + if hook: + hook(in_values, out_values, rest_values) + + yield in_values, out_values, rest_values + + +def _flatten(iterator): + return (sample for batch in iterator for sample in batch) diff --git a/chainercv/utils/iterator/progress_hook.py b/chainercv/utils/iterator/progress_hook.py index f895361e20..4835752a33 100644 --- a/chainercv/utils/iterator/progress_hook.py +++ b/chainercv/utils/iterator/progress_hook.py @@ -19,16 +19,16 @@ def __init__(self, n_total=None): self.start = time.time() self.n_processed = 0 - def __call__(self, imgs, pred_values, gt_values): - self.n_processed += len(imgs) + def __call__(self, in_values, out_values, rest_values): + self.n_processed += len(in_values[0]) fps = self.n_processed / (time.time() - self.start) if self.n_total is not None: sys.stdout.write( - '\r{:d} of {:d} images, {:.2f} FPS'.format( + '\r{:d} of {:d} samples, {:.2f} samples/sec'.format( self.n_processed, self.n_total, fps)) else: sys.stdout.write( - '\r{:d} images, {:.2f} FPS'.format( + '\r{:d} samples, {:.2f} samples/sec'.format( self.n_processed, fps)) sys.stdout.flush() diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst index 3b952002b4..2d6a9b83d4 100644 --- a/docs/source/reference/utils.rst +++ b/docs/source/reference/utils.rst @@ -51,9 +51,9 @@ write_image Iterator Utilities ------------------ -apply_prediction_to_iterator -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: apply_prediction_to_iterator +apply_to_iterator +~~~~~~~~~~~~~~~~~ +.. autofunction:: apply_to_iterator ProgressHook ~~~~~~~~~~~~ diff --git a/examples/classification/eval_imagenet.py b/examples/classification/eval_imagenet.py index 90215c60e3..79e8ba7025 100644 --- a/examples/classification/eval_imagenet.py +++ b/examples/classification/eval_imagenet.py @@ -11,7 +11,7 @@ from chainercv.links import FeaturePredictor from chainercv.links import VGG16 -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator from chainercv.utils import ProgressHook @@ -42,12 +42,12 @@ def main(): model.to_gpu() print('Model has been prepared. Evaluation starts.') - imgs, pred_values, gt_values = apply_prediction_to_iterator( + in_values, out_values, rest_values = apply_to_iterator( model.predict, iterator, hook=ProgressHook(len(dataset))) - del imgs + del in_values - pred_probs, = pred_values - gt_labels, = gt_values + pred_probs, = out_values + gt_labels, = rest_values accuracy = F.accuracy( np.array(list(pred_probs)), np.array(list(gt_labels))).data diff --git a/examples/detection/eval_voc07.py b/examples/detection/eval_voc07.py index e5d9c81452..c6c182a03f 100644 --- a/examples/detection/eval_voc07.py +++ b/examples/detection/eval_voc07.py @@ -9,7 +9,7 @@ from chainercv.links import FasterRCNNVGG16 from chainercv.links import SSD300 from chainercv.links import SSD512 -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator from chainercv.utils import ProgressHook @@ -56,13 +56,13 @@ def main(): iterator = iterators.SerialIterator( dataset, args.batchsize, repeat=False, shuffle=False) - imgs, pred_values, gt_values = apply_prediction_to_iterator( + in_values, out_values, rest_values = apply_to_iterator( model.predict, iterator, hook=ProgressHook(len(dataset))) - # delete unused iterator explicitly - del imgs + # delete unused iterators explicitly + del in_values - pred_bboxes, pred_labels, pred_scores = pred_values - gt_bboxes, gt_labels, gt_difficults = gt_values + pred_bboxes, pred_labels, pred_scores = out_values + gt_bboxes, gt_labels, gt_difficults = rest_values result = eval_detection_voc( pred_bboxes, pred_labels, pred_scores, diff --git a/examples/segnet/eval_camvid.py b/examples/segnet/eval_camvid.py index 2e25d3c7db..683eb416ef 100644 --- a/examples/segnet/eval_camvid.py +++ b/examples/segnet/eval_camvid.py @@ -11,7 +11,7 @@ from chainercv.datasets import CamVidDataset from chainercv.evaluations import eval_semantic_segmentation from chainercv.links import SegNetBasic -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator from chainercv.utils import ProgressHook @@ -60,12 +60,12 @@ def main(): it = chainer.iterators.SerialIterator(test, batch_size=args.batchsize, repeat=False, shuffle=False) - imgs, pred_values, gt_values = apply_prediction_to_iterator( + in_values, out_values, rest_values = apply_to_iterator( model.predict, it, hook=ProgressHook(len(test))) # Delete an iterator of images to save memory usage. - del imgs - pred_labels, = pred_values - gt_labels, = gt_values + del in_values + pred_labels, = out_values + gt_labels, = rest_values result = eval_semantic_segmentation(pred_labels, gt_labels) diff --git a/tests/utils_tests/iterator_tests/test_apply_prediction_to_iterator.py b/tests/utils_tests/iterator_tests/test_apply_prediction_to_iterator.py deleted file mode 100644 index a9a2f62a9d..0000000000 --- a/tests/utils_tests/iterator_tests/test_apply_prediction_to_iterator.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np -from six.moves import zip_longest -import unittest - -import chainer -from chainer.iterators import SerialIterator -from chainer import testing - -from chainercv.utils import apply_prediction_to_iterator - - -@testing.parameterize(*testing.product({ - 'multi_pred_values': [False, True], - 'with_gt_values': [False, True], - 'with_hook': [False, True], -})) -class TestApplyPredictionToIterator(unittest.TestCase): - - def test_apply_prediction_to_iterator(self): - if self.multi_pred_values: - def predict(imgs): - n_img = len(imgs) - return ( - [np.random.uniform(size=(10, 4)) for _ in range(n_img)], - [np.random.uniform(size=10) for _ in range(n_img)], - [np.random.uniform(size=10) for _ in range(n_img)]) - - n_pred_values = 3 - else: - def predict(imgs): - n_img = len(imgs) - return [np.random.uniform(size=(48, 64)) for _ in range(n_img)] - - n_pred_values = 1 - - dataset_imgs = [] - for _ in range(5): - H, W = np.random.randint(8, 16, size=2) - dataset_imgs.append(np.random.randint(0, 256, size=(3, H, W))) - - if self.with_gt_values: - strs = ['a', 'bc', 'def', 'ghij', 'klmno'] - nums = [0, 1, 2, 3, 4] - arrays = [np.random.uniform(size=10) for _ in range(5)] - - dataset = chainer.datasets.TupleDataset( - dataset_imgs, strs, nums, arrays) - dataset_gt_values = (strs, nums, arrays) - else: - dataset = dataset_imgs - dataset_gt_values = tuple() - iterator = SerialIterator(dataset, 2, repeat=False, shuffle=False) - - if self.with_hook: - def hook(imgs, pred_values, gt_values): - self.assertEqual(len(pred_values), n_pred_values) - for pred_vals in pred_values: - self.assertEqual(len(pred_vals), len(imgs)) - - self.assertEqual(len(gt_values), len(dataset_gt_values)) - for gt_vals in gt_values: - self.assertEqual(len(gt_vals), len(imgs)) - else: - hook = None - - imgs, pred_values, gt_values = apply_prediction_to_iterator( - predict, iterator, hook=hook) - - for img, dataset_img in zip_longest(imgs, dataset_imgs): - np.testing.assert_equal(img, dataset_img) - - self.assertEqual(len(pred_values), n_pred_values) - for vals in pred_values: - self.assertEqual(len(list(vals)), len(dataset_imgs)) - - for vals, dataset_vals in zip_longest(gt_values, dataset_gt_values): - for val, dataset_val in zip_longest(vals, dataset_vals): - if isinstance(dataset_val, np.ndarray): - np.testing.assert_equal(val, dataset_val) - else: - self.assertEqual(val, dataset_val) - - -class TestApplyPredictionToIteratorWithInfiniteIterator(unittest.TestCase): - - def test_apply_prediction_to_iterator_with_infinite_iterator(self): - def predict(imgs): - n_img = len(imgs) - return [np.random.uniform(size=(48, 64)) for _ in range(n_img)] - - dataset = [] - for _ in range(5): - H, W = np.random.randint(8, 16, size=2) - dataset.append(np.random.randint(0, 256, size=(3, H, W))) - - iterator = SerialIterator(dataset, 2) - - imgs, pred_values, gt_values = apply_prediction_to_iterator( - predict, iterator) - - for _ in range(10): - next(imgs) - - for _ in range(10): - next(pred_values[0]) - - -testing.run_module(__name__, __file__) diff --git a/tests/utils_tests/iterator_tests/test_apply_to_iterator.py b/tests/utils_tests/iterator_tests/test_apply_to_iterator.py new file mode 100644 index 0000000000..969d9c47f2 --- /dev/null +++ b/tests/utils_tests/iterator_tests/test_apply_to_iterator.py @@ -0,0 +1,134 @@ +import numpy as np +from six.moves import zip_longest +import unittest + +import chainer +from chainer.iterators import SerialIterator +from chainer import testing + +from chainercv.utils import apply_to_iterator + + +@testing.parameterize(*testing.product({ + 'multi_in_values': [False, True], + 'multi_out_values': [False, True], + 'with_rest_values': [False, True], + 'with_hook': [False, True], +})) +class TestApplyToIterator(unittest.TestCase): + + def test_apply_to_iterator(self): + if self.multi_in_values: + n_input = 2 + else: + n_input = 1 + + in_values_expect = [] + for _ in range(n_input): + in_value = [] + for _ in range(5): + H, W = np.random.randint(8, 16, size=2) + in_value.append(np.random.randint(0, 256, size=(3, H, W))) + in_values_expect.append(in_value) + in_values_expect = tuple(in_values_expect) + + if self.multi_out_values: + def func(*in_values): + n_sample = len(in_values[0]) + return ( + [np.random.uniform(size=(10, 4)) for _ in range(n_sample)], + [np.random.uniform(size=10) for _ in range(n_sample)], + [np.random.uniform(size=10) for _ in range(n_sample)]) + + n_output = 3 + else: + def func(*in_values): + n_sample = len(in_values[0]) + return [np.random.uniform(size=(48, 64)) + for _ in range(n_sample)] + + n_output = 1 + + if self.with_rest_values: + strs = ['a', 'bc', 'def', 'ghij', 'klmno'] + nums = [0, 1, 2, 3, 4] + arrays = [np.random.uniform(size=10) for _ in range(5)] + rest_values_expect = (strs, nums, arrays) + n_rest = 3 + + dataset = chainer.datasets.TupleDataset( + *(in_values_expect + rest_values_expect)) + else: + rest_values_expect = () + n_rest = 0 + + dataset = list(zip(*in_values_expect)) + + iterator = SerialIterator(dataset, 2, repeat=False, shuffle=False) + + if self.with_hook: + def hook(in_values, out_values, rest_values): + n_sample = len(in_values[0]) + + self.assertEqual(len(in_values), n_input) + for in_vals in in_values: + self.assertEqual(len(in_vals), n_sample) + + self.assertEqual(len(out_values), n_output) + for out_vals in out_values: + self.assertEqual(len(out_vals), n_sample) + + self.assertEqual(len(rest_values), n_rest) + for rest_vals in rest_values: + self.assertEqual(len(rest_vals), n_sample) + else: + hook = None + + in_values, out_values, rest_values = apply_to_iterator( + func, iterator, n_input=n_input, hook=hook) + + self.assertEqual(len(in_values), n_input) + for in_vals, in_vals_expect in \ + zip_longest(in_values, in_values_expect): + for in_val, in_val_expect in zip_longest(in_vals, in_vals_expect): + np.testing.assert_equal(in_val, in_val_expect) + + self.assertEqual(len(out_values), n_output) + for out_vals in out_values: + self.assertEqual(len(list(out_vals)), len(dataset)) + + self.assertEqual(len(rest_values), n_rest) + for rest_vals, rest_vals_expect in \ + zip_longest(rest_values, rest_values_expect): + for rest_val, rest_val_expect in \ + zip_longest(rest_vals, rest_vals_expect): + if isinstance(rest_val_expect, np.ndarray): + np.testing.assert_equal(rest_val, rest_val_expect) + else: + self.assertEqual(rest_val, rest_val_expect) + + +class TestApplyToIteratorWithInfiniteIterator(unittest.TestCase): + + def test_apply_to_iterator_with_infinite_iterator(self): + def func(*in_values): + n_sample = len(in_values[0]) + return [np.random.uniform(size=(48, 64)) for _ in range(n_sample)] + + dataset = [] + for _ in range(5): + H, W = np.random.randint(8, 16, size=2) + dataset.append(np.random.randint(0, 256, size=(3, H, W))) + + iterator = SerialIterator(dataset, 2) + + in_values, out_values, rest_values = apply_to_iterator(func, iterator) + + for _ in range(10): + next(in_values[0]) + + for _ in range(10): + next(out_values[0]) + + +testing.run_module(__name__, __file__) diff --git a/tests/utils_tests/iterator_tests/test_progress_hook.py b/tests/utils_tests/iterator_tests/test_progress_hook.py index 14ab65d071..2ef4f43599 100644 --- a/tests/utils_tests/iterator_tests/test_progress_hook.py +++ b/tests/utils_tests/iterator_tests/test_progress_hook.py @@ -4,18 +4,18 @@ from chainer.iterators import SerialIterator from chainer import testing -from chainercv.utils import apply_prediction_to_iterator +from chainercv.utils import apply_to_iterator from chainercv.utils import ProgressHook class TestProgressHook(unittest.TestCase): def setUp(self): - def predict(imgs): - n_img = len(imgs) - return [np.random.uniform() for _ in range(n_img)] + def func(*in_values): + n_sample = len(in_values[0]) + return [np.random.uniform() for _ in range(n_sample)] - self.predict = predict + self.func = func self.dataset = [] for _ in range(5): @@ -25,22 +25,22 @@ def predict(imgs): def test_progress_hook(self): iterator = SerialIterator(self.dataset, 2, repeat=False) - imgs, pred_values, gt_values = apply_prediction_to_iterator( - self.predict, iterator, + in_values, out_values, rest_values = apply_to_iterator( + self.func, iterator, hook=ProgressHook(n_total=len(self.dataset))) # consume all data - for _ in imgs: + for _ in in_values[0]: pass def test_progress_hook_with_infinite_iterator(self): iterator = SerialIterator(self.dataset, 2) - imgs, pred_values, gt_values = apply_prediction_to_iterator( - self.predict, iterator, hook=ProgressHook()) + in_values, out_values, rest_values = apply_to_iterator( + self.func, iterator, hook=ProgressHook()) for _ in range(10): - next(imgs) + next(in_values[0]) testing.run_module(__name__, __file__)