forked from qiskit-community/qiskit-algorithms
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalgorithms_test_case.py
88 lines (70 loc) · 2.86 KB
/
algorithms_test_case.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2018, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Algorithms Test Case"""
from typing import Optional
from abc import ABC
import warnings
import inspect
import logging
import os
import unittest
import time
from qiskit_algorithms.utils import algorithm_globals
# disable deprecation warnings that can cause log output overflow
# pylint: disable=unused-argument
def _noop(*args, **kargs):
pass
# disable warning messages
# warnings.warn = _noop
class QiskitAlgorithmsTestCase(unittest.TestCase, ABC):
    """Base test case for the qiskit-algorithms test suite.

    Provides:
      * per-test wall-clock timing (tests slower than 5 seconds print their
        elapsed time on teardown),
      * a reset of ``algorithm_globals.random_seed`` after every test,
      * optional per-class file logging enabled by the ``LOG_LEVEL``
        environment variable,
      * a helper to resolve paths to test resources.
    """

    # Path of the concrete test class's source file without its extension;
    # populated by setUpClass.
    moduleName: Optional[str] = None
    # Logger named after the concrete test class; populated by setUpClass.
    log: Optional[logging.Logger] = None

    def setUp(self) -> None:
        super().setUp()
        # Use the "default" action for DeprecationWarning (print the first
        # occurrence per location) instead of Python's usual ignore filter.
        warnings.filterwarnings("default", category=DeprecationWarning)
        self._started_at = time.time()
        # Resources are resolved relative to the directory of this base
        # module, not the directory of the concrete test subclass.
        self._class_location = __file__

    def tearDown(self) -> None:
        # Reset the global seed after each test (presumably so a seed set by
        # one test does not carry over to the next).
        algorithm_globals.random_seed = None
        # A subclass overriding setUp without calling super() never sets
        # _started_at; skip timing instead of raising AttributeError.
        started_at = getattr(self, "_started_at", None)
        if started_at is not None:
            elapsed = time.time() - started_at
            # Only report noticeably slow tests.
            if elapsed > 5.0:
                print(f"({elapsed:.2f}s)", flush=True)
        super().tearDown()

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls.moduleName = os.path.splitext(inspect.getfile(cls))[0]
        cls.log = logging.getLogger(cls.__name__)

        # Log to a file next to the test module, but only when the LOG_LEVEL
        # environment variable is set (no console handler is installed here).
        level_name = os.getenv("LOG_LEVEL")
        if not level_name:
            return

        log_file_name = os.path.abspath(f"{cls.moduleName}.log")
        # getLogger() returns the same logger object for the same name, so a
        # repeated setUpClass must not attach a second handler for this file
        # (that would duplicate every log line).
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file_name
            for handler in cls.log.handlers
        )
        if not already_attached:
            log_fmt = f"{cls.__name__}.%(funcName)s:%(levelname)s:%(asctime)s: %(message)s"
            file_handler = logging.FileHandler(log_file_name)
            file_handler.setFormatter(logging.Formatter(log_fmt))
            cls.log.addHandler(file_handler)

        # Level names are matched case-insensitively; an unrecognized name
        # falls back to INFO.
        level = logging._nameToLevel.get(level_name.upper(), logging.INFO)
        cls.log.setLevel(level)

    def get_resource_path(self, filename: str, path: Optional[str] = None) -> str:
        """Get the absolute path to a resource.

        Args:
            filename: filename or relative path to the resource.
            path: optional directory, relative to the directory of this base
                module, in which ``filename`` is resolved.

        Returns:
            str: the normalized absolute path to the resource.
        """
        # Fall back to this module's location if setUp() was not run.
        location = getattr(self, "_class_location", __file__)
        root = os.path.dirname(location)
        base = root if path is None else os.path.join(root, path)
        return os.path.normpath(os.path.join(base, filename))