Skip to content

Commit

Permalink
add segmentation reference consistency tests (#6591)
Browse files Browse the repository at this point in the history
* add segmentation reference consistency tests

* fall back to smoke tests for resize

* add test for RandomCrop

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
pmeier and datumbox authored Sep 23, 2022
1 parent 0a946d5 commit 7046e56
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 8 deletions.
3 changes: 2 additions & 1 deletion test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(

def _process_inputs(self, actual, expected, *, id, allow_subclasses):
actual, expected = [
to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
for input in [actual, expected]
]
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
# image to a tensor adds a singleton leading dimension.
Expand Down
182 changes: 175 additions & 7 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import enum
import inspect
import random
from collections import defaultdict
from importlib.machinery import SourceFileLoader
from pathlib import Path

Expand All @@ -16,13 +18,15 @@
make_image,
make_images,
make_label,
make_segmentation_mask,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms._utils import query_chw
from torchvision.prototype.transforms.functional import to_image_pil


# Default keyword arguments for generating test images: RGB color space only and a single `extra_dims`
# entry of `(4,)`.
# NOTE(review): presumably forwarded to `make_images`; the call sites are not visible here -- confirm.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])


Expand Down Expand Up @@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
assert_equal(expected_output, output)


# Import reference detection transforms here for consistency checks
# torchvision/references/detection/transforms.py
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` as a fresh, standalone module.

    The reference scripts are not an importable package, so the file is loaded directly from its path,
    resolved relative to the parent of the directory that contains this file.

    Every reference file is named ``transforms.py`` and is therefore loaded under the module name
    ``"transforms"``. The deprecated ``Loader.load_module()`` registers the module in ``sys.modules`` and, on a
    subsequent call with the same name, re-executes the new source inside the *already existing* module
    object. Loading a second set of references would thus silently overwrite the first one. Building the
    module from a spec gives every call its own module object and leaves ``sys.modules`` untouched.

    Args:
        reference: Name of the sub-directory of ``references`` to load from, e.g. ``"detection"``.

    Returns:
        The newly created module object holding the executed contents of ``transforms.py``.

    Raises:
        FileNotFoundError: If the ``transforms.py`` file of the requested reference does not exist.
    """
    # Local import: the file only imports `SourceFileLoader` from `importlib.machinery` at the top level.
    import importlib.util

    transforms_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"

    loader = SourceFileLoader(transforms_path.stem, transforms_path.as_posix())
    spec = importlib.util.spec_from_loader(loader.name, loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module


# Reference detection transforms (references/detection/transforms.py), loaded here for the consistency checks.
det_transforms = import_transforms_from_references("detection")


class TestRefDetTransforms:
Expand All @@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):

yield (pil_image, target)

tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
Expand All @@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):

yield (tensor_image, target)

feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
Expand Down Expand Up @@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
expected_output = t_ref(*dp)

assert_equal(expected_output, output)


# Reference segmentation transforms (references/segmentation/transforms.py), loaded here for the consistency
# checks.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop pads images and masks of insufficient size differently than its namesake in the
#    segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
    """Pad a sample whose height or width is below ``size``; pass it through unchanged otherwise.

    Args:
        size: Minimum height and width the sample should have after the transform.
        fill: Fill specification handed to ``_setup_fill_arg``. In ``_transform`` the result is indexed by the
            type of each input, so different input types can receive different fill values.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        """Compute the padding for ``sample`` and whether any padding is required at all."""
        _, height, width = query_chw(sample)

        # Only the last two entries can be non-zero: the width and the height shortfall, respectively.
        # NOTE(review): presumably the (left, top, right, bottom) convention of `F.pad` -- confirm.
        missing_width = max(self.size - width, 0)
        missing_height = max(self.size - height, 0)
        padding = [0, 0, missing_width, missing_height]

        return dict(padding=padding, needs_padding=any(padding))

    def _transform(self, inpt, params):
        """Pad ``inpt`` with the fill registered for its type, or return it as is if it is large enough."""
        if params["needs_padding"]:
            fill = F._geometry._convert_fill_arg(self.fill[type(inpt)])
            return F.pad(inpt, padding=params["padding"], fill=fill)

        return inpt


class TestRefSegTransforms:
    """Consistency checks of the prototype transforms against the segmentation reference transforms."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(dp, dp_ref)`` pairs, one per supported image representation.

        ``dp`` is the ``(image, mask)`` sample for the prototype transform, ``dp_ref`` the corresponding sample
        for the reference transform. A fresh image and mask are generated for every representation.

        Args:
            supports_pil: Whether PIL images are exercised. If not, the reference sample carries the image as
                ``torch.Tensor`` instead of as PIL image.
            image_dtype: dtype of the generated image.
        """
        size = (256, 640)
        num_categories = 21

        # PIL image (optional), `torch.Tensor(...)` of the image, and the untouched feature image.
        converters = [to_image_pil] if supports_pil else []
        converters += [torch.Tensor, lambda x: x]

        for convert in converters:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (convert(feature_image), feature_mask)

            if supports_pil:
                ref_image = to_image_pil(feature_image)
            else:
                ref_image = torch.Tensor(feature_image)
            dp_ref = (ref_image, to_image_pil(feature_mask))

            yield dp, dp_ref

    def set_seed(self, seed=12):
        """Seed both the ``torch`` and the ``random`` generators."""
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Assert that ``t`` and ``t_ref`` produce equal outputs when run from the same seed."""
        make_kwargs = data_kwargs or dict()

        for sample, ref_sample in self.make_datapoints(**make_kwargs):
            self.set_seed()
            actual = t(sample)

            self.set_seed()
            expected = t_ref(*ref_sample)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                {},
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                {},
            ),
            (
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                {},
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                {"supports_pil": False, "image_dtype": torch.float},
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        """Compare the outputs of each reference / prototype transform pair."""
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        """Smoke test for resizing: compare the arguments passed to the mocked ``resize`` functions.

        Both resize functions are replaced by mocks. For each sample, either transform has to call its resize
        function exactly twice, with the elements of the sample, in order, as first positional argument. The
        second positional argument of a prototype call has to equal the second positional argument of the
        matching reference call wrapped in a list.
        """
        resize_mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        ref_resize_mock = mocker.patch("torchvision.transforms.functional.resize")

        for sample, ref_sample in self.make_datapoints():
            resize_mock.reset_mock()
            ref_resize_mock.reset_mock()

            self.set_seed()
            t(sample)
            assert resize_mock.call_count == 2
            first_args = [recorded[0][0] for recorded in resize_mock.call_args_list]
            for passed, original in zip(first_args, sample):
                assert passed is original

            self.set_seed()
            t_ref(*ref_sample)
            assert ref_resize_mock.call_count == 2
            ref_first_args = [recorded[0][0] for recorded in ref_resize_mock.call_args_list]
            for passed, original in zip(ref_first_args, ref_sample):
                assert passed is original

            recorded_pairs = zip(resize_mock.call_args_list, ref_resize_mock.call_args_list)
            for recorded, ref_recorded in recorded_pairs:
                assert recorded[0][1] == [ref_recorded[0][1]]

    def test_random_resize_train(self, mocker):
        """Check ``RandomResize`` of the prototype transforms against the one of the references."""
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        # Keep a handle on the real function, since the name is patched below.
        original_randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            defer_to_torch = bool(kwargs) or len(other_args) > 1 or other_args[0] != ()
            if not defer_to_torch:
                return random.randint(a, b)

            return original_randint(a, b, *other_args, **kwargs)

        # torch.randint is redirected to random.randint, because the reference modules are not imported
        # normally and thus cannot be patched.
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        """Check ``Resize`` against the reference ``RandomResize`` with ``min_size == max_size``."""
        torch.manual_seed(0)
        base_size = 520

        t = prototype_transforms.Resize(size=base_size, antialias=True)
        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)

0 comments on commit 7046e56

Please sign in to comment.