Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] refactor HRSCDataset #457

Merged
merged 13 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
380 changes: 146 additions & 234 deletions mmrotate/datasets/hrsc.py

Large diffs are not rendered by default.

31 changes: 18 additions & 13 deletions mmrotate/evaluation/metrics/dota_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ class DOTAMetric(BaseMetric):
metric (str | list[str]): Metrics to be evaluated. Only support
'mAP' now. If is list, the first setting in the list will
be used to evaluate metric.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Defaults to (100, 300, 1000).
format_only (bool): Format the output results without perform
evaluation. It is useful when you want to format the result
to a specific format. Defaults to False.
Expand All @@ -50,7 +47,12 @@ class DOTAMetric(BaseMetric):
patches' results.
iou_thr (float): IoU threshold of ``nms_rotated`` used in merge
patches. Defaults to 0.1.
version (str): Angle representations. Defaults to 'oc'.
angle_version (str): Angle representations. Defaults to 'oc'.
eval_mode (str): 'area' or '11points', 'area' means calculating the
area under precision-recall curve, '11points' means calculating
the average precision of recalls at [0, 0.1, ..., 1].
The PASCAL VOC2007 defaults to use '11points', while PASCAL
VOC2012 defaults to use 'area'. Defaults to '11points'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
Expand All @@ -66,12 +68,12 @@ def __init__(self,
iou_thrs: Union[float, List[float]] = 0.5,
scale_ranges: Optional[List[tuple]] = None,
metric: Union[str, List[str]] = 'mAP',
proposal_nums: Sequence[int] = (100, 300, 1000),
format_only: bool = False,
outfile_prefix: Optional[str] = None,
merge_patches: bool = False,
iou_thr: float = 0.1,
version: str = 'oc',
angle_version: str = 'oc',
eval_mode: str = '11points',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
Expand All @@ -83,12 +85,10 @@ def __init__(self,
if not isinstance(metric, str):
assert len(metric) == 1
metric = metric[0]
allowed_metrics = ['recall', 'mAP']
allowed_metrics = ['mAP']
if metric not in allowed_metrics:
raise KeyError(
f"metric should be one of 'recall', 'mAP', but got {metric}.")
raise KeyError(f"metric should be one of 'mAP', but got {metric}.")
self.metric = metric
self.proposal_nums = proposal_nums

self.format_only = format_only
if self.format_only:
Expand All @@ -99,7 +99,10 @@ def __init__(self,
self.outfile_prefix = outfile_prefix
self.merge_patches = merge_patches
self.iou_thr = iou_thr
self.version = version
self.angle_version = angle_version
            assert eval_mode in ['area', '11points'], \
'Unrecognized mode, only "area" and "11points" are supported'
self.use_07_metric = True if eval_mode == '11points' else False

def merge_results(self, results: Sequence[dict],
outfile_prefix: str) -> str:
Expand Down Expand Up @@ -163,7 +166,7 @@ def merge_results(self, results: Sequence[dict],
for f, dets in zip(file_objs, dets_per_cls):
if dets.size == 0:
continue
bboxes = obb2poly_np(dets, self.version)
bboxes = obb2poly_np(dets, self.angle_version)
for bbox in bboxes:
txt_element = [img_id, str(bbox[-1])
] + [f'{p:.2f}' for p in bbox[:-1]]
Expand Down Expand Up @@ -300,15 +303,17 @@ def compute_metrics(self, results: list) -> dict:
if self.metric == 'mAP':
assert isinstance(self.iou_thrs, list)
dataset_name = self.dataset_meta['CLASSES']
dets = [pred['pred_bbox_scores'] for pred in preds]

mean_aps = []
for iou_thr in self.iou_thrs:
logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
mean_ap, _ = eval_rbbox_map(
preds['pred_bbox_scores'],
dets,
gts,
scale_ranges=self.scale_ranges,
iou_thr=iou_thr,
use_07_metric=self.use_07_metric,
dataset=dataset_name,
logger=logger)
mean_aps.append(mean_ap)
Expand Down
58 changes: 58 additions & 0 deletions tests/data/hrsc/FullDataSet/Annotations/100000006.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<HRSC_Image>
<Img_ID>100000006</Img_ID>
<Place_ID>100000001</Place_ID>
<Source_ID>100000001</Source_ID>
<Img_NO>100000006</Img_NO>
<Img_FileName>100000006</Img_FileName>
<Img_FileFmt>bmp</Img_FileFmt>
<Img_Date>1900-01-01</Img_Date>
<Img_CusType>sealand</Img_CusType>
<Img_Des>
</Img_Des>
<Img_Location>69.040297,33.070036</Img_Location>
<Img_SizeWidth>1172</Img_SizeWidth>
<Img_SizeHeight>816</Img_SizeHeight>
<Img_SizeDepth>3</Img_SizeDepth>
<Img_Resolution>1.07</Img_Resolution>
<Img_Resolution_Layer>18</Img_Resolution_Layer>
<Img_Scale>100</Img_Scale>
<Img_SclPxlNum>
</Img_SclPxlNum>
<segmented>0</segmented>
<Img_Havemask>0</Img_Havemask>
<Img_MaskFileName>
</Img_MaskFileName>
<Img_MaskFileFmt>
</Img_MaskFileFmt>
<Img_MaskType>
</Img_MaskType>
<Img_SegFileName>
</Img_SegFileName>
<Img_SegFileFmt>
</Img_SegFileFmt>
<Img_Rotation>000d</Img_Rotation>
<Annotated>1</Annotated>
<HRSC_Objects>
<HRSC_Object>
<Object_ID>100000006</Object_ID>
<Class_ID>100000013</Class_ID>
<Object_NO>100000006</Object_NO>
<truncated>0</truncated>
<difficult>0</difficult>
<box_xmin>119</box_xmin>
<box_ymin>75</box_ymin>
<box_xmax>587</box_xmax>
<box_ymax>789</box_ymax>
<mbox_cx>341.2143</mbox_cx>
<mbox_cy>443.3325</mbox_cy>
<mbox_w>778.4297</mbox_w>
<mbox_h>178.2595</mbox_h>
<mbox_ang>-1.122944</mbox_ang>
<segmented>0</segmented>
<seg_color>
</seg_color>
<header_x>143</header_x>
<header_y>776</header_y>
</HRSC_Object>
</HRSC_Objects>
</HRSC_Image>
1 change: 1 addition & 0 deletions tests/data/hrsc/demo.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
100000006
6 changes: 3 additions & 3 deletions tests/test_datasets/test_dota.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestDOTADataset(unittest.TestCase):

def test_dota_with_ann_file(self):
dataset = DOTADataset(
data_root='tests/data/',
data_root='tests/data/dota/',
ann_file='labelTxt/',
data_prefix=dict(img_path='images/'),
filter_cfg=dict(
Expand All @@ -22,7 +22,7 @@ def test_dota_with_ann_file(self):
self.assertEqual(data_list[0]['img_id'], 'P2805__1024__0___0')
self.assertEqual(data_list[0]['file_name'], 'P2805__1024__0___0.png')
self.assertEqual(data_list[0]['img_path'],
'tests/data/images/P2805__1024__0___0.png')
'tests/data/dota/images/P2805__1024__0___0.png')
self.assertEqual(len(data_list[0]['instances']), 4)
self.assertEqual(dataset.get_cat_ids(0), [0, 0, 0, 0])

Expand All @@ -41,6 +41,6 @@ def test_dota_without_ann_file(self):
self.assertEqual(data_list[0]['img_id'], 'P2805__1024__0___0')
self.assertEqual(data_list[0]['file_name'], 'P2805__1024__0___0.png')
self.assertEqual(data_list[0]['img_path'],
'tests/data/images/P2805__1024__0___0.png')
'tests/data/dota/images/P2805__1024__0___0.png')
self.assertEqual(len(data_list[0]['instances']), 1)
self.assertEqual(dataset.get_cat_ids(0), [[]])
54 changes: 54 additions & 0 deletions tests/test_datasets/test_hrsc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

from mmrotate.datasets import HRSCDataset


class TestHRSCDataset(unittest.TestCase):

def test_hrsc(self):
dataset = HRSCDataset(
data_root='tests/data/hrsc/',
ann_file='demo.txt',
data_prefix=dict(sub_data_root='FullDataSet/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=[])
dataset.full_init()
self.assertEqual(len(dataset), 1)

data_list = dataset.load_data_list()
self.assertEqual(len(data_list), 1)
self.assertEqual(data_list[0]['img_id'], '100000006')
self.assertEqual(
data_list[0]['img_path'],
'tests/data/hrsc/FullDataSet/AllImages/100000006.bmp')
self.assertEqual(
data_list[0]['xml_path'],
'tests/data/hrsc/FullDataSet/Annotations/100000006.xml')
self.assertEqual(len(data_list[0]['instances']), 1)
self.assertEqual(dataset.get_cat_ids(0), [0])
self.assertEqual(dataset._metainfo['CLASSES'], ('ship', ))

def test_hrsc_classwise(self):
dataset = HRSCDataset(
data_root='tests/data/hrsc/',
ann_file='demo.txt',
data_prefix=dict(sub_data_root='FullDataSet/'),
classwise=True,
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=[])
dataset.full_init()
self.assertEqual(len(dataset), 1)

data_list = dataset.load_data_list()
self.assertEqual(len(data_list), 1)
self.assertEqual(data_list[0]['img_id'], '100000006')
self.assertEqual(
data_list[0]['img_path'],
'tests/data/hrsc/FullDataSet/AllImages/100000006.bmp')
self.assertEqual(
data_list[0]['xml_path'],
'tests/data/hrsc/FullDataSet/Annotations/100000006.xml')
self.assertEqual(len(data_list[0]['instances']), 1)
self.assertEqual(dataset.get_cat_ids(0), [12])
self.assertEqual(len(dataset._metainfo['CLASSES']), 31)