-
Notifications
You must be signed in to change notification settings - Fork 302
Improve apply_prediction #523
Changes from 9 commits
4179b00
6ebf6e6
4afd154
14d0973
cdbc657
592d30a
4db7d8b
2a6a4a2
7a57c6c
77acf2e
2d3cb55
a4e7c2e
2e052a4
f920744
cc91196
0bd2f6d
2e449a3
583ed87
09f7ade
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from chainercv.utils.iterator.apply_prediction_to_iterator import apply_prediction_to_iterator # NOQA | ||
from chainercv.utils.iterator.apply_to_batch import apply_to_batch # NOQA | ||
from chainercv.utils.iterator.progress_hook import ProgressHook # NOQA | ||
from chainercv.utils.iterator.unzip import unzip # NOQA |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from chainercv.utils.iterator.unzip import unzip | ||
|
||
|
||
def apply_to_batch(func, iterator, n_input=1, hook=None): | ||
"""Apply a function/method to an iterator of batches. | ||
|
||
This function applies a function/method to an iterator of batches. | ||
It assumes that the iterator returns a batch of input data or | ||
    a batch of tuples whose first elements are input data. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought it was better to describe the internals of this function more in the beginning. This function assumes that the iterator iterates over a collection of tuples that contain inputs to Here is an illustration of the expected behavior of
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This explanation sounds this function returns three iterators. However, it is not correct. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. That is my mistake. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I will try to merge your solution into docstring. |
||
|
||
>>> batch = next(iterator) | ||
>>> # batch: [in_val] | ||
or | ||
>>> # batch: [(in_val0, in_val1, ...)] | ||
or | ||
>>> # batch: [(in_val0, in_val1, ..., rest_val0, rest_val1, ...)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
This function applies :func:`func` to batch(es) of input data and gets | ||
computed value(s). :func:`func` should take a batch of data and | ||
return a batch of computed values | ||
or a tuple of batches of computed values. | ||
|
||
>>> out_vals = func(in_val0, in_val1, ...) | ||
>>> # out_vals: [out_val] | ||
or | ||
>>> out_vals0, out_vals1, ... = func(in_val0, in_val1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these examples are not necessary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to show the returned values should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah. I see. So you wanted to emphasize that the function returns tuple of lists. BTW, shouldn't function take list of values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. That's my mistake. |
||
>>> # out_vals0: [out_val0] | ||
>>> # out_vals1: [out_val1] | ||
|
||
    Here is an example, which applies a pretrained Faster R-CNN to | ||
PASCAL VOC dataset. | ||
|
||
>>> from chainer import iterators | ||
>>> | ||
>>> from chainercv.datasets import VOCDetectionDataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. VOCBBoxDataset. This is my mistake. Sorry. |
||
>>> from chainercv.links import FasterRCNNVGG16 | ||
>>> from chainercv.utils import apply_to_batch | ||
>>> | ||
>>> dataset = VOCDetectionDataset(year='2007', split='test') | ||
>>> # next(iterator) -> [(img, gt_bbox, gt_label)] | ||
>>> iterator = iterators.SerialIterator( | ||
... dataset, 2, repeat=False, shuffle=False) | ||
>>> | ||
>>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score]) | ||
>>> model = FasterRCNNVGG16(pretrained_model='voc07') | ||
>>> | ||
>>> in_values, out_values, rest_values = apply_to_batch( | ||
... model.predict, iterator) | ||
>>> | ||
>>> # in_values contains one iterator | ||
>>> imgs, = in_values | ||
>>> # out_values contains three iterators | ||
>>> pred_bboxes, pred_labels, pred_scores = out_values | ||
>>> # rest_values contains two iterators | ||
>>> gt_bboxes, gt_labels = rest_values | ||
|
||
Args: | ||
func: A callable that takes batch(es) of input data and returns | ||
computed data. | ||
iterator (chainer.Iterator): An iterator of batch. | ||
The first :obj:`n_input` elements in each sample are | ||
treated as input values. They are passed to :obj:`func`. | ||
n_input (int): The number of input data. The default value is :obj:`1`. | ||
hook: A callable that is called after each iteration. | ||
:obj:`in_values`, :obj:`out_values`, and :obj:`rest_values` | ||
are passed as arguments. | ||
Note that these values do not contain data from the previous | ||
iterations. | ||
|
||
Returns: | ||
Three tuples of iterators: | ||
This function returns three tuples of iterators: | ||
:obj:`in_values`, :obj:`out_values` and :obj:`rest_values`. | ||
|
||
* :obj:`in_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding input value. \ | ||
For example, if :func:`func` takes \ | ||
:obj:`[in_val0], [in_val1]`, :obj:`next(in_values[0])` \ | ||
and :obj:`next(in_values[1])` will be \ | ||
:obj:`in_val0` and :obj:`in_val1`. | ||
* :obj:`out_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding computed value. \ | ||
For example, if :func:`func` returns \ | ||
:obj:`([out_val0], [out_val1])`, :obj:`next(out_values[0])` \ | ||
and :obj:`next(out_values[1])` will be \ | ||
:obj:`out_val0` and :obj:`out_val1`. | ||
* :obj:`rest_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding rest value. \ | ||
For example, if the :obj:`iterator` returns \ | ||
:obj:`[(in_val0, in_val1, rest_val0, rest_val1)]`, \ | ||
:obj:`next(rest_values[0])` \ | ||
and :obj:`next(rest_values[1])` will be \ | ||
:obj:`rest_val0` and :obj:`rest_val1`. \ | ||
If the input \ | ||
iterator does not give any rest values, this tuple \ | ||
will be empty. | ||
""" | ||
|
||
in_values, out_values, rest_values = unzip( | ||
_apply(func, iterator, n_input, hook)) | ||
|
||
# in_values: iter of ([in_val0], [in_val1], ...) | ||
# -> (iter of in_val0, iter of in_val1, ...) | ||
in_values = tuple(map(_flatten, unzip(in_values))) | ||
|
||
# out_values: iter of ([out_val0], [out_val1], ...) | ||
# -> (iter of out_val0, iter of out_val1, ...) | ||
out_values = tuple(map(_flatten, unzip(out_values))) | ||
|
||
# rest_values: iter of ([rest_val0], [rest_val1], ...) | ||
# -> (iter of rest_val0, iter of rest_val1, ...) | ||
rest_values = tuple(map(_flatten, unzip(rest_values))) | ||
|
||
return in_values, out_values, rest_values | ||
|
||
|
||
def _apply(func, iterator, n_input, hook): | ||
for batch in iterator: | ||
# batch: [(in_val0, in_val1, ... , rest_val0, rest_val1, ...)] or | ||
# [in_val] | ||
|
||
in_values = [] | ||
rest_values = [] | ||
for sample in batch: | ||
if isinstance(sample, tuple): | ||
in_values.append(sample[0:n_input]) | ||
rest_values.append(sample[n_input:]) | ||
else: | ||
in_values.append((sample,)) | ||
rest_values.append(tuple()) | ||
|
||
# in_values: [(in_val0, in_val1, ...)] | ||
# -> ([in_val0], [in_val1], ...) | ||
in_values = tuple(list(v) for v in zip(*in_values)) | ||
|
||
# rest_values: [(rest_val0, rest_val1, ...)] | ||
# -> ([rest_val0], [rest_val1], ...) | ||
rest_values = tuple(list(v) for v in zip(*rest_values)) | ||
|
||
# out_values: ([out_val0], [out_val1], ...) or [out_val] | ||
out_values = func(*in_values) | ||
if not isinstance(out_values, tuple): | ||
# pred_values: [out_val] -> ([out_val],) | ||
out_values = out_values, | ||
|
||
if hook: | ||
hook(in_values, out_values, rest_values) | ||
|
||
yield in_values, out_values, rest_values | ||
|
||
|
||
def _flatten(iterator): | ||
return (sample for batch in iterator for sample in batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not directly related to the change proposed in this PR, but this sentence is incorrect.
We apply a function to batches and not an iterator.
Thus,
Apply a function/method to batches of data from an iterator
is more precise. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I looked up the usage of
apply
on the internet. https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.apply.html