Commit 9f69471

[mypy] extend mypy check (#3154)
### Changes

Added the following files to the mypy check and fixed the reported errors:

- nncf/quantization/passes.py
- nncf/quantization/advanced_parameters.py
- nncf/quantization/range_estimator.py
- nncf/quantization/telemetry_extractors.py

1 parent 909ce0a commit 9f69471

File tree

5 files changed: +49 -29 lines changed


nncf/common/graph/graph.py (+17 -3)
@@ -11,7 +11,21 @@
 import pathlib
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, Union, ValuesView, cast
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    Generator,
+    KeysView,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    ValuesView,
+    cast,
+)
 
 import networkx as nx  # type:ignore
 import networkx.algorithms.isomorphism as iso  # type:ignore
@@ -245,7 +259,7 @@ def get_nodes_by_types(self, type_list: List[str]) -> List[NNCFNode]:
                 all_nodes_of_type.append(nncf_node)
         return all_nodes_of_type
 
-    def get_nodes_by_metatypes(self, metatype_list: List[Type[OperatorMetatype]]) -> List[NNCFNode]:
+    def get_nodes_by_metatypes(self, metatype_list: Collection[Type[OperatorMetatype]]) -> List[NNCFNode]:
         """
         Return a list of nodes with provided metatypes.
 
@@ -766,7 +780,7 @@ def get_all_edges(self) -> Generator[NNCFGraphEdge, None, None]:
         for nx_edge in self._nx_graph.in_edges:
             yield self.get_edge(self.get_node_by_key(nx_edge[0]), self.get_node_by_key(nx_edge[1]))
 
-    def remove_nodes_from(self, nodes: List[NNCFNode]) -> None:
+    def remove_nodes_from(self, nodes: Collection[NNCFNode]) -> None:
         """
         Removes nodes from the current NNCFGraph instance.
         We use the remove_node method here because remove_nodes_from uses a silent fail instead of an exception.
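
For context, a minimal sketch of why widening the parameter type from List to Collection helps callers (Metatype here is a hypothetical stand-in, not the real OperatorMetatype class): any container, not just a list, can now be passed without an extra conversion.

from typing import Collection, List, Type


class Metatype:  # hypothetical stand-in for OperatorMetatype, for illustration only
    pass


def get_nodes_by_metatypes(metatype_list: Collection[Type[Metatype]]) -> List[str]:
    # Collection covers lists, tuples, sets and dict views alike,
    # so callers no longer need to build an intermediate list.
    return [metatype.__name__ for metatype in metatype_list]


print(get_nodes_by_metatypes([Metatype]))   # a list works
print(get_nodes_by_metatypes({Metatype}))   # a set works too
print(get_nodes_by_metatypes((Metatype,)))  # as does a tuple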

nncf/quantization/advanced_parameters.py (+16 -14)
@@ -182,7 +182,7 @@ class AdvancedQuantizationParameters:
 
     :param overflow_fix: This option controls whether to apply the overflow issue fix
         for the 8-bit quantization.
-    :type overflow_fix: nncf.quantization.advanced_parameters.OverflowFix
+    :type overflow_fix: Optional[nncf.quantization.advanced_parameters.OverflowFix]
     :param quantize_outputs: Whether to insert additional quantizers right before each
         of the model outputs.
     :type quantize_outputs: bool
@@ -232,16 +232,16 @@ class AdvancedQuantizationParameters:
     """
 
     # General parameters
-    overflow_fix: OverflowFix = None
+    overflow_fix: Optional[OverflowFix] = None
     quantize_outputs: bool = False
     inplace_statistics: bool = True
    disable_channel_alignment: bool = True
     disable_bias_correction: bool = False
     batchwise_statistics: Optional[bool] = None
 
     # Advanced Quantization parameters
-    activations_quantization_params: Union[QuantizationParameters, FP8QuantizationParameters] = None
-    weights_quantization_params: Union[QuantizationParameters, FP8QuantizationParameters] = None
+    activations_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None
+    weights_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None
     quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE
 
     # Range estimator parameters
@@ -254,7 +254,7 @@ class AdvancedQuantizationParameters:
     # Advanced SmoothQuant algorithm parameters
     smooth_quant_alphas: AdvancedSmoothQuantParameters = field(default_factory=AdvancedSmoothQuantParameters)
     # Deprecated parameter
-    smooth_quant_alpha: float = None
+    smooth_quant_alpha: Optional[float] = None
 
     # Backend specific parameters
     backend_params: Dict[str, Any] = field(default_factory=dict)
@@ -460,14 +460,14 @@ def convert_to_dict_recursively(params: Any) -> Dict[str, Any]:
     return result
 
 
-def convert_quantization_parameters_to_dict(params: QuantizationParameters) -> Dict[str, Any]:
+def convert_quantization_parameters_to_dict(params: Optional[QuantizationParameters]) -> Dict[str, Any]:
     """
     Converts quantization parameters to the dict in the legacy format
 
     :param params: Quantization parameters
     :return: Quantization parameters as dict in the legacy format
     """
-    result = {}
+    result: Dict[str, Any] = {}
     if params is not None:
         if params.num_bits is not None:
             result["bits"] = params.num_bits
@@ -492,7 +492,7 @@ def convert_range_estimator_parameters_to_dict(params: RangeEstimatorParameters)
     if params.min.clipping_value is not None or params.max.clipping_value is not None:
         raise nncf.ParameterNotSupportedError("clipping_value parameter is not supported in the legacy format")
 
-    result = {}
+    result: Dict[str, Any] = {}
     if (
         params.min.statistics_type == StatisticsType.MIN
         and params.min.aggregator_type == AggregatorType.MIN
@@ -551,13 +551,15 @@ def apply_advanced_parameters_to_config(
         initializer["batchnorm_adaptation"] = {"num_bn_adaptation_samples": 0}
         config["initializer"] = initializer
 
-    activations_config = convert_quantization_parameters_to_dict(params.activations_quantization_params)
-    if activations_config:
-        config["activations"] = activations_config
+    if isinstance(params.activations_quantization_params, QuantizationParameters):
+        activations_config = convert_quantization_parameters_to_dict(params.activations_quantization_params)
+        if activations_config:
+            config["activations"] = activations_config
 
-    weights_config = convert_quantization_parameters_to_dict(params.weights_quantization_params)
-    if weights_config:
-        config["weights"] = weights_config
+    if isinstance(params.weights_quantization_params, QuantizationParameters):
+        weights_config = convert_quantization_parameters_to_dict(params.weights_quantization_params)
+        if weights_config:
+            config["weights"] = weights_config
 
     activations_init_range_config = convert_range_estimator_parameters_to_dict(
         params.activations_range_estimator_params
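
As a rough illustration of the main fix in this file (names simplified, not the real NNCF classes): a dataclass field that defaults to None must be annotated Optional, otherwise mypy reports an incompatible default under its default no-implicit-optional behaviour.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class OverflowFix(Enum):  # simplified stand-in for the real enum
    ENABLE = "enable"
    DISABLE = "disable"


@dataclass
class Params:
    # overflow_fix: OverflowFix = None          # rejected by mypy: None is not an OverflowFix
    overflow_fix: Optional[OverflowFix] = None  # accepted: the field may legitimately be unset


params = Params()
print(params.overflow_fix)  # None until the caller chooses a value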

nncf/quantization/passes.py (+8 -8)
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import collections
-from typing import List, TypeVar
+from typing import Deque, List, Type, TypeVar
 
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
@@ -23,9 +23,9 @@
 def transform_to_inference_graph(
     nncf_graph: NNCFGraph,
     input_nodes: List[NNCFNode],
-    shapeof_metatypes: List[OperatorMetatype],
-    dropout_metatypes: List[OperatorMetatype],
-    preserved_metatypes: List[OperatorMetatype],
+    shapeof_metatypes: List[Type[OperatorMetatype]],
+    dropout_metatypes: List[Type[OperatorMetatype]],
+    preserved_metatypes: List[Type[OperatorMetatype]],
 ) -> NNCFGraph:
     """
     This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
@@ -49,7 +49,7 @@ def transform_to_inference_graph(
 
 def find_shapeof_subgraphs(
     nncf_graph: NNCFGraph,
-    shapeof_metatypes: List[OperatorMetatype],
+    shapeof_metatypes: List[Type[OperatorMetatype]],
     input_nodes: List[NNCFNode],
 ) -> List[NNCFNode]:
     """
@@ -80,7 +80,7 @@ def find_shapeof_subgraphs(
     for shape_of_node in shape_of_nodes:
         shapeof_subgraphs.add(shape_of_node)
 
-        shape_of_queue = collections.deque()
+        shape_of_queue: Deque[NNCFNode] = collections.deque()
         shape_of_queue.extend(nncf_graph.get_next_nodes(shape_of_node))
         while shape_of_queue:
             node = shape_of_queue.pop()
@@ -97,7 +97,7 @@
 def find_preserved_nodes(
     graph: NNCFGraph,
     shapeof_subgraphs: List[NNCFNode],
-    preserved_metatypes: List[OperatorMetatype],
+    preserved_metatypes: List[Type[OperatorMetatype]],
 ) -> List[NNCFNode]:
     """
     :param graph: The input graph to be analyzed.
@@ -129,7 +129,7 @@ def find_preserved_nodes(
 
 def remove_nodes_and_reconnect_graph(
     nncf_graph: NNCFGraph,
-    metatypes: List[OperatorMetatype],
+    metatypes: List[Type[OperatorMetatype]],
 ) -> NNCFGraph:
     """
     Removes nodes with metatypes specified by `metatypes` parameter from
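
A small sketch of the deque annotation pattern used above (generic names, not tied to NNCF): an empty collections.deque() gives mypy nothing to infer an element type from, so the variable is annotated explicitly.

import collections
from typing import Deque

# An empty deque has no element type mypy can infer, so annotate it explicitly.
queue: Deque[str] = collections.deque()
queue.extend(["conv", "relu"])
while queue:
    node = queue.pop()  # mypy now knows node is a str
    print(node)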

nncf/quantization/quantize_model.py (+4 -4)
@@ -67,7 +67,7 @@ def warning_model_no_batchwise_support(
     :param graph: Model's NNCFGraph.
     :param advanced_quantization_parameters: AdvancedQuantizationParameters.
     :param model_type: Model type algorithm option.
-    :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support.
+    :param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support.
     """
     if is_model_no_batchwise_support(
         graph, advanced_quantization_parameters, model_type, no_batchwise_support_metatypes
@@ -80,16 +80,16 @@ def is_model_no_batchwise_support(
     advanced_quantization_parameters: Optional[AdvancedQuantizationParameters],
     model_type: ModelType,
     no_batchwise_support_metatypes: List[OperatorMetatype],
-) -> None:
+) -> bool:
     """
     Returns True if batchwise statistics could lead to a significant accuracy drop.
 
     :param graph: Model's NNCFGraph.
     :param advanced_quantization_parameters: AdvancedQuantizationParameters.
     :param model_type: Model type algorithm option.
-    :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support.
+    :param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support.
     """
-    return (
+    return bool(
         advanced_quantization_parameters
         and advanced_quantization_parameters.batchwise_statistics
         and (graph.get_nodes_by_metatypes(no_batchwise_support_metatypes) or model_type == ModelType.TRANSFORMER)
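
A minimal sketch (hypothetical function, not the NNCF API) of why the return value is wrapped in bool(): a chained and/or expression evaluates to one of its operands, so without the cast the function could return None or a list even though it is annotated -> bool.

from typing import List, Optional


def needs_warning(flag: Optional[bool], nodes: List[str]) -> bool:
    # `flag and nodes` evaluates to None, False, [] or the list itself;
    # bool() collapses all of these to the plain True/False the annotation promises.
    return bool(flag and nodes)


print(needs_warning(True, ["node"]))  # True
print(needs_warning(None, ["node"]))  # False (None short-circuits the chain)
print(needs_warning(True, []))        # False (an empty list is falsy)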

pyproject.toml (+4)
@@ -100,6 +100,10 @@ files = [
     "nncf/common/utils/",
     "nncf/common/tensor_statistics",
     "nncf/experimental/torch2",
+    "nncf/quantization/passes.py",
+    "nncf/quantization/advanced_parameters.py",
+    "nncf/quantization/range_estimator.py",
+    "nncf/quantization/telemetry_extractors.py",
     "nncf/telemetry/",
 ]
