From d2c2488c77b299b182dbf2ce8eab558ec696b636 Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Mon, 2 Nov 2020 07:56:56 +0000 Subject: [PATCH] add unittests, test=develop --- .../fluid/tests/unittests/test_op_version.py | 67 ++++++++++++++++++ python/paddle/utils/__init__.py | 1 + python/paddle/utils/op_version.py | 70 +++++++++++++++++++ 3 files changed, 138 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_op_version.py create mode 100644 python/paddle/utils/op_version.py diff --git a/python/paddle/fluid/tests/unittests/test_op_version.py b/python/paddle/fluid/tests/unittests/test_op_version.py new file mode 100644 index 00000000000000..06367310daaf7a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_op_version.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.utils as utils +import paddle.fluid as fluid + + +class OpLastCheckpointCheckerTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(OpLastCheckpointCheckerTest, self).__init__(methodName) + self.checker = utils.OpLastCheckpointChecker() + + def test_op_attr_info(self): + update_type = fluid.core.OpUpdateType.kNewAttr + info_list = self.checker.filter_updates('arg_max', update_type, + 'flatten') + self.assertTrue(info_list) + self.assertTrue(info_list[0].name()) + self.assertTrue(info_list[0].default_value() == False) + self.assertTrue(info_list[0].remark()) + + def test_op_input_output_info(self): + update_type = fluid.core.OpUpdateType.kNewInput + info_list = self.checker.filter_updates('roi_align', update_type, + 'RoisNum') + self.assertTrue(info_list) + self.assertTrue(info_list[0].name()) + self.assertTrue(info_list[0].remark()) + + def test_op_bug_fix_info(self): + update_type = fluid.core.OpUpdateType.kBugfixWithBehaviorChanged + info_list = self.checker.filter_updates('leaky_relu', update_type) + self.assertTrue(info_list) + self.assertTrue(info_list[0].remark()) + + +class OpVersionTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(OpVersionTest, self).__init__(methodName) + self.vmap = fluid.core.get_op_version_map() + + def test_checkpoints(self): + version_id = self.vmap['arg_max'].version_id() + checkpoints = self.vmap['arg_max'].checkpoints() + self.assertTrue(version_id) + self.assertTrue(checkpoints) + self.assertTrue(checkpoints[0].note()) + self.assertTrue(checkpoints[0].version_desc().infos()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 9d7a05131ffa13..faf0fd4984d7ca 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -17,6 +17,7 @@ from .profiler import get_profiler from .deprecated import deprecated from .lazy_import import try_import +from .op_version import OpLastCheckpointChecker from .install_check import run_check from ..fluid.framework import unique_name from ..fluid.framework import load_op_library diff --git a/python/paddle/utils/op_version.py b/python/paddle/utils/op_version.py new file mode 100644 index 00000000000000..68acc9de081518 --- /dev/null +++ b/python/paddle/utils/op_version.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..fluid import core + +__all__ = ['OpLastCheckpointChecker'] + + +def Singleton(cls): + _instance = {} + + def _singleton(*args, **kargs): + if cls not in _instance: + _instance[cls] = cls(*args, **kargs) + return _instance[cls] + + return _singleton + + +class OpUpdateInfoHelper(object): + def __init__(self, info): + self._info = info + + def verify_key_value(self, name=''): + result = False + key_funcs = { + core.OpAttrInfo: 'name', + core.OpInputOutputInfo: 'name', + } + if name == '': + result = True + elif type(self._info) in key_funcs: + if getattr(self._info, key_funcs[type(self._info)])() == name: + result = True + return result + + +@Singleton +class OpLastCheckpointChecker(object): + def __init__(self): + self.raw_version_map = core.get_op_version_map() + self.checkpoints_map = {} + self._construct_map() + + def _construct_map(self): + for op_name in self.raw_version_map: + last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1] + infos = last_checkpoint.version_desc().infos() + self.checkpoints_map[op_name] = infos + + def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''): + updates = [] + if op_name in self.checkpoints_map: + for update in self.checkpoints_map[op_name]: + if (update.type() == type) or ( + type == core.OpUpdateType.kInvalid): + if OpUpdateInfoHelper(update.info()).verify_key_value(key): + updates.append(update.info()) + return updates