From 73f0c7630b06e898c4296558021cbe055f0e7f1a Mon Sep 17 00:00:00 2001 From: zytx121 <592267829@qq.com> Date: Sun, 13 Nov 2022 09:20:14 +0800 Subject: [PATCH 1/2] update --- mmrotate/visualization/local_visualizer.py | 11 ++++++++++- tests/test_visualization/test_local_visualizer.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mmrotate/visualization/local_visualizer.py b/mmrotate/visualization/local_visualizer.py index 735d2d5e8..314d66c93 100644 --- a/mmrotate/visualization/local_visualizer.py +++ b/mmrotate/visualization/local_visualizer.py @@ -7,8 +7,10 @@ from mmdet.visualization import DetLocalVisualizer, jitter_color from mmdet.visualization.palette import _get_adaptive_scales from mmengine.structures import InstanceData +from torch import Tensor from mmrotate.registry import VISUALIZERS +from mmrotate.structures.bbox import QuadriBoxes, RotatedBoxes from .palette import get_palette @@ -68,7 +70,14 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], bbox_palette = get_palette(bbox_color, max_label + 1) colors = [bbox_palette[label] for label in labels] - # convert to qbox + if isinstance(bboxes, Tensor): + if bboxes.size(-1) == 5: + bboxes = RotatedBoxes(bboxes) + elif bboxes.size(-1) == 8: + bboxes = QuadriBoxes(bboxes) + else: + NotImplementedError + polygons = bboxes.convert_to('qbox').tensor polygons = polygons.reshape(-1, 4, 2).numpy() polygons = [p for p in polygons] diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py index ba5442b18..15667bdd1 100644 --- a/tests/test_visualization/test_local_visualizer.py +++ b/tests/test_visualization/test_local_visualizer.py @@ -37,7 +37,7 @@ def test_add_datasample(self): # test gt_instances gt_instances = InstanceData() - gt_instances.bboxes = RotatedBoxes(_rand_rbboxes(num_bboxes, h, w)) + gt_instances.bboxes = _rand_rbboxes(num_bboxes, h, w) gt_instances.masks = _fake_masks(num_bboxes, h, w) gt_instances.labels = torch.randint(0, num_class, (num_bboxes, )) det_data_sample = DetDataSample() From 86f97d875cf5df39afb54fc47e9702ad6a4dae4e Mon Sep 17 00:00:00 2001 From: zytx121 <592267829@qq.com> Date: Tue, 15 Nov 2022 09:02:54 +0800 Subject: [PATCH 2/2] Update local_visualizer.py --- mmrotate/visualization/local_visualizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmrotate/visualization/local_visualizer.py b/mmrotate/visualization/local_visualizer.py index 314d66c93..3d54e9fce 100644 --- a/mmrotate/visualization/local_visualizer.py +++ b/mmrotate/visualization/local_visualizer.py @@ -76,7 +76,10 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], elif bboxes.size(-1) == 8: bboxes = QuadriBoxes(bboxes) else: - NotImplementedError + raise TypeError( + 'Require the shape of `bboxes` to be (n, 5) ' + 'or (n, 8), but get `bboxes` with shape being ' + f'{bboxes.shape}.') polygons = bboxes.convert_to('qbox').tensor polygons = polygons.reshape(-1, 4, 2).numpy()