Skip to content

Commit

Permalink
add unittests, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Nov 2, 2020
1 parent 551499f commit d2c2488
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
67 changes: 67 additions & 0 deletions python/paddle/fluid/tests/unittests/test_op_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle.utils as utils
import paddle.fluid as fluid


class OpLastCheckpointCheckerTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpLastCheckpointCheckerTest, self).__init__(methodName)
self.checker = utils.OpLastCheckpointChecker()

def test_op_attr_info(self):
update_type = fluid.core.OpUpdateType.kNewAttr
info_list = self.checker.filter_updates('arg_max', update_type,
'flatten')
self.assertTrue(info_list)
self.assertTrue(info_list[0].name())
self.assertTrue(info_list[0].default_value() == False)
self.assertTrue(info_list[0].remark())

def test_op_input_output_info(self):
update_type = fluid.core.OpUpdateType.kNewInput
info_list = self.checker.filter_updates('roi_align', update_type,
'RoisNum')
self.assertTrue(info_list)
self.assertTrue(info_list[0].name())
self.assertTrue(info_list[0].remark())

def test_op_bug_fix_info(self):
update_type = fluid.core.OpUpdateType.kBugfixWithBehaviorChanged
info_list = self.checker.filter_updates('leaky_relu', update_type)
self.assertTrue(info_list)
self.assertTrue(info_list[0].remark())


class OpVersionTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpVersionTest, self).__init__(methodName)
self.vmap = fluid.core.get_op_version_map()

def test_checkpoints(self):
version_id = self.vmap['arg_max'].version_id()
checkpoints = self.vmap['arg_max'].checkpoints()
self.assertTrue(version_id)
self.assertTrue(checkpoints)
self.assertTrue(checkpoints[0].note())
self.assertTrue(checkpoints[0].version_desc().infos())


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .profiler import get_profiler
from .deprecated import deprecated
from .lazy_import import try_import
from .op_version import OpLastCheckpointChecker
from .install_check import run_check
from ..fluid.framework import unique_name
from ..fluid.framework import load_op_library
Expand Down
70 changes: 70 additions & 0 deletions python/paddle/utils/op_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..fluid import core

__all__ = ['OpLastCheckpointChecker']


def Singleton(cls):
_instance = {}

def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]

return _singleton


class OpUpdateInfoHelper(object):
def __init__(self, info):
self._info = info

def verify_key_value(self, name=''):
result = False
key_funcs = {
core.OpAttrInfo: 'name',
core.OpInputOutputInfo: 'name',
}
if name == '':
result = True
elif type(self._info) in key_funcs:
if getattr(self._info, key_funcs[type(self._info)])() == name:
result = True
return result


@Singleton
class OpLastCheckpointChecker(object):
def __init__(self):
self.raw_version_map = core.get_op_version_map()
self.checkpoints_map = {}
self._construct_map()

def _construct_map(self):
for op_name in self.raw_version_map:
last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1]
infos = last_checkpoint.version_desc().infos()
self.checkpoints_map[op_name] = infos

def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''):
updates = []
if op_name in self.checkpoints_map:
for update in self.checkpoints_map[op_name]:
if (update.type() == type) or (
type == core.OpUpdateType.kInvalid):
if OpUpdateInfoHelper(update.info()).verify_key_value(key):
updates.append(update.info())
return updates

0 comments on commit d2c2488

Please sign in to comment.