Skip to content

Commit 85f3afb

Browse files
committed
support Pytorch model in the weight compression algorithm
1 parent d285f47 commit 85f3afb

31 files changed

+1077
-815
lines changed

nncf/common/graph/layer_attributes.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,11 @@ class ConvertDtypeLayerAttributes(BaseLayerAttributes):
271271

272272

273273
@dataclass
274-
class ParameterLayerAttributes(BaseLayerAttributes):
274+
class ConstantLayerAttributes(BaseLayerAttributes):
275275
"""
276-
:param name: Parameter name.
276+
:param name: Constant name.
277+
:param shape: Constant shape.
277278
"""
278279

279280
name: str
281+
shape: List[int]

nncf/common/graph/transformations/commands.py

-13
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,3 @@ def __init__(self, command_type: TransformationType, target_point: TargetPoint):
214214
@property
215215
def target_point(self) -> TargetPoint:
216216
return self._target_point
217-
218-
def check_command_compatibility(self, command: "TransformationCommand") -> bool:
219-
return (
220-
isinstance(command, TransformationCommand)
221-
and self.type == command.type
222-
and self.target_point == command.target_point
223-
)
224-
225-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
226-
raise NotImplementedError()
227-
228-
def __add__(self, other: "TransformationCommand") -> "TransformationCommand":
229-
return self.union(other)

nncf/experimental/tensor/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from nncf.experimental.tensor.enums import TensorBackendType as TensorBackendType
13-
from nncf.experimental.tensor.enums import TensorDataType as TensorDataType
14-
from nncf.experimental.tensor.enums import TensorDeviceType as TensorDeviceType
12+
from nncf.experimental.tensor.definitions import TensorBackendType as TensorBackendType
13+
from nncf.experimental.tensor.definitions import TensorDataType as TensorDataType
14+
from nncf.experimental.tensor.definitions import TensorDeviceType as TensorDeviceType
1515
from nncf.experimental.tensor.tensor import Tensor as Tensor
1616
from nncf.experimental.tensor.tensor import unwrap_tensor_data as unwrap_tensor_data

nncf/experimental/tensor/enums.py nncf/experimental/tensor/definitions.py

+17
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from dataclasses import dataclass
1213
from enum import Enum
1314
from enum import auto
1415

@@ -32,6 +33,7 @@ class TensorDataType(Enum):
3233
float64 = auto()
3334
int8 = auto()
3435
uint8 = auto()
36+
int32 = auto()
3537

3638

3739
class TensorDeviceType(Enum):
@@ -41,3 +43,18 @@ class TensorDeviceType(Enum):
4143

4244
CPU = auto()
4345
GPU = auto()
46+
47+
48+
@dataclass
49+
class TypeInfo:
50+
"""
51+
The class represents the numerical properties of a floating point types.
52+
53+
:param eps: The smallest representable number such that 1.0 + eps != 1.0.
54+
:param max: The largest representable number.
55+
:param min: The smallest representable number (typically -max).
56+
"""
57+
58+
eps: float
59+
max: float
60+
min: float

nncf/experimental/tensor/functions.py

+48-5
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
# limitations under the License.
1111

1212
import functools
13-
from typing import Callable, List, Optional, Tuple, TypeVar, Union
13+
from typing import Any, Callable, List, Optional, Tuple, Union
1414

15-
from nncf.experimental.tensor.enums import TensorDataType
16-
from nncf.experimental.tensor.enums import TensorDeviceType
15+
from nncf.experimental.tensor.definitions import TensorDataType
16+
from nncf.experimental.tensor.definitions import TensorDeviceType
17+
from nncf.experimental.tensor.definitions import TypeInfo
1718
from nncf.experimental.tensor.tensor import Tensor
1819
from nncf.experimental.tensor.tensor import unwrap_tensor_data
1920

20-
TypeInfo = TypeVar("TypeInfo")
21-
2221

2322
def _tensor_guard(func: callable):
2423
"""
@@ -442,6 +441,50 @@ def finfo(a: Tensor) -> TypeInfo:
442441
return finfo(a.data)
443442

444443

444+
@functools.singledispatch
445+
@_tensor_guard
446+
def clip(a: Tensor, a_min: Union[Tensor, float], a_max: Union[Tensor, float]) -> Tensor:
447+
"""
448+
Clips all elements in input into the range [ a_min, a_max ]
449+
450+
:param a: Tensor.
451+
:param a_min: A lower-bound of the range to be clamped to.
452+
:param a_max: An upper-bound of the range to be clamped to.
453+
:return: A clipped tensor with the elements of a, but where values < a_min are replaced with a_min,
454+
and those > a_max with a_max.
455+
"""
456+
return Tensor(clip(a.data, unwrap_tensor_data(a_min), unwrap_tensor_data(a_max)))
457+
458+
459+
@functools.singledispatch
460+
@_tensor_guard
461+
def as_tensor_like(a: Tensor, data: Any) -> Tensor:
462+
"""
463+
Converts the data into a tensor with the same data representation and hosted on the same device
464+
as the given tensor.
465+
466+
:param a: A tensor for defining the data representation and the host device of the output tensor.
467+
:param data: Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
468+
:return: A tensor with the same data representation and hosted on the same device as a,
469+
and which has been initialized with data.
470+
"""
471+
return Tensor(as_tensor_like(a.data, data))
472+
473+
474+
@functools.singledispatch
475+
@_tensor_guard
476+
def item(a: Tensor) -> Union[int, float, bool]:
477+
"""
478+
Returns the value of this tensor as a standard Python number. This only works for tensors with one element.
479+
480+
:param a: Tensor.
481+
:return: The value of this tensor as a standard Python number
482+
"""
483+
if isinstance(a.data, (int, float, bool)):
484+
return a.data
485+
return item(a.data)
486+
487+
445488
def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs):
446489
"""
447490
Dispatches the function to the type of the wrapped data of the first element in tensor_list.

nncf/experimental/tensor/numpy_functions.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,22 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, List, Optional, Tuple, Union
12+
from typing import Any, Callable, List, Optional, Tuple, Union
1313

1414
import numpy as np
1515

1616
from nncf.experimental.tensor import functions as fns
17-
from nncf.experimental.tensor.enums import TensorDataType
18-
from nncf.experimental.tensor.enums import TensorDeviceType
17+
from nncf.experimental.tensor.definitions import TensorDataType
18+
from nncf.experimental.tensor.definitions import TensorDeviceType
19+
from nncf.experimental.tensor.definitions import TypeInfo
1920

2021
DTYPE_MAP = {
2122
TensorDataType.float16: np.dtype(np.float16),
2223
TensorDataType.float32: np.dtype(np.float32),
2324
TensorDataType.float64: np.dtype(np.float64),
2425
TensorDataType.int8: np.dtype(np.int8),
2526
TensorDataType.uint8: np.dtype(np.uint8),
27+
TensorDataType.int32: np.dtype(np.int32),
2628
}
2729

2830
DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
@@ -209,5 +211,25 @@ def _(
209211

210212

211213
@_register_numpy_types(fns.finfo)
212-
def _(a: np.ndarray) -> np.finfo:
213-
return np.finfo(a.dtype)
214+
def _(a: np.ndarray) -> TypeInfo:
215+
ti = np.finfo(a.dtype)
216+
return TypeInfo(ti.eps, ti.max, ti.min)
217+
218+
219+
@_register_numpy_types(fns.clip)
220+
def _(
221+
a: Union[np.ndarray, np.generic],
222+
a_min: Union[np.ndarray, np.generic, float],
223+
a_max: Union[np.ndarray, np.generic, float],
224+
) -> Union[np.ndarray, np.generic]:
225+
return np.clip(a, a_min, a_max)
226+
227+
228+
@_register_numpy_types(fns.as_tensor_like)
229+
def _(a: Union[np.ndarray, np.generic], data: Any) -> Union[np.ndarray, np.generic]:
230+
return np.array(data)
231+
232+
233+
@_register_numpy_types(fns.item)
234+
def _(a: Union[np.ndarray, np.generic]) -> Union[int, float, bool]:
235+
return a.item()

nncf/experimental/tensor/tensor.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import operator
1414
from typing import Any, Optional, Tuple, TypeVar, Union
1515

16-
from nncf.experimental.tensor.enums import TensorDataType
17-
from nncf.experimental.tensor.enums import TensorDeviceType
16+
from nncf.experimental.tensor.definitions import TensorDataType
17+
from nncf.experimental.tensor.definitions import TensorDeviceType
1818

1919
TTensor = TypeVar("TTensor")
2020

@@ -44,7 +44,7 @@ def device(self) -> TensorDeviceType:
4444
return _call_function("device", self)
4545

4646
@property
47-
def dtype(self) -> TensorDeviceType:
47+
def dtype(self) -> TensorDataType:
4848
return _call_function("dtype", self)
4949

5050
def __bool__(self) -> bool:
@@ -146,6 +146,9 @@ def astype(self, dtype: TensorDataType) -> Tensor:
146146
def reshape(self, shape: Tuple[int, ...]) -> Tensor:
147147
return _call_function("reshape", self, shape)
148148

149+
def item(self) -> float:
150+
return _call_function("item", self)
151+
149152

150153
def _call_function(func_name: str, *args):
151154
"""

nncf/experimental/tensor/torch_functions.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,22 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, List, Optional, Tuple, Union
12+
from typing import Any, Callable, List, Optional, Tuple, Union
1313

1414
import torch
1515

1616
from nncf.experimental.tensor import TensorDataType
1717
from nncf.experimental.tensor import TensorDeviceType
1818
from nncf.experimental.tensor import functions as fns
19+
from nncf.experimental.tensor.definitions import TypeInfo
1920

2021
DTYPE_MAP = {
2122
TensorDataType.float16: torch.float16,
2223
TensorDataType.float32: torch.float32,
2324
TensorDataType.float64: torch.float64,
2425
TensorDataType.int8: torch.int8,
2526
TensorDataType.uint8: torch.uint8,
27+
TensorDataType.int32: torch.int32,
2628
}
2729

2830
DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
@@ -200,3 +202,24 @@ def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) ->
200202
@fns._binary_reverse_op_nowarn.register(torch.Tensor)
201203
def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor:
202204
return operator_fn(b, a)
205+
206+
207+
@fns.clip.register(torch.Tensor)
208+
def _(a: torch.Tensor, a_min: Union[torch.Tensor, float], a_max: Union[torch.Tensor, float]) -> torch.Tensor:
209+
return torch.clip(a, a_min, a_max)
210+
211+
212+
@fns.finfo.register(torch.Tensor)
213+
def _(a: torch.Tensor) -> TypeInfo:
214+
ti = torch.finfo(a.dtype)
215+
return TypeInfo(ti.eps, ti.max, ti.min)
216+
217+
218+
@fns.as_tensor_like.register(torch.Tensor)
219+
def _(a: torch.Tensor, data: Any) -> torch.Tensor:
220+
return torch.as_tensor(data, device=a.device)
221+
222+
223+
@fns.item.register(torch.Tensor)
224+
def _(a: torch.Tensor) -> Union[int, float, bool]:
225+
return a.item()

nncf/onnx/graph/transformations/commands.py

+2-27
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ def __init__(self, target_point: ONNXTargetPoint, input_edges_mapping: Dict[str,
5656
# need to keep the mapping NNCF input nodes to the following ONNX nodes.
5757
self.input_edges_mapping = input_edges_mapping
5858

59-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
60-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
61-
raise NotImplementedError()
62-
6359

6460
class ONNXQuantizerInsertionCommand(ONNXInsertionCommand):
6561
def __init__(
@@ -71,15 +67,10 @@ def __init__(
7167
super().__init__(target_point, nncf_input_node_next_onnx_nodes)
7268
self.quantizer_parameters = quantizer_parameters
7369

74-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
75-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
76-
raise NotImplementedError()
77-
7870

7971
class ONNXOutputInsertionCommand(ONNXInsertionCommand):
80-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
81-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
82-
raise NotImplementedError()
72+
def __init__(self, target_point: ONNXTargetPoint, input_edges_mapping: Dict[str, Tuple[str, int]]):
73+
super().__init__(TransformationType.INSERT, target_point, input_edges_mapping)
8374

8475

8576
class ONNXBiasCorrectionCommand(TransformationCommand):
@@ -95,10 +86,6 @@ def __init__(self, target_point: ONNXTargetPoint, bias_value: np.ndarray):
9586
super().__init__(TransformationType.CHANGE, target_point)
9687
self.bias_value = bias_value
9788

98-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
99-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
100-
raise NotImplementedError()
101-
10289

10390
class ONNXModelExtractionCommand(Command):
10491
"""
@@ -114,10 +101,6 @@ def __init__(self, inputs: List[str], outputs: List[str]):
114101
self.inputs = inputs
115102
self.outputs = outputs
116103

117-
def union(self, other: "Command") -> "Command":
118-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
119-
raise NotImplementedError()
120-
121104

122105
class ONNXQDQNodeRemovingCommand(TransformationCommand):
123106
"""
@@ -130,10 +113,6 @@ def __init__(self, target_point: ONNXTargetPoint):
130113
"""
131114
super().__init__(TransformationType.REMOVE, target_point)
132115

133-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
134-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
135-
raise NotImplementedError()
136-
137116

138117
class ONNXNullBiasInsertionCommand(TransformationCommand):
139118
"""
@@ -145,7 +124,3 @@ def __init__(self, target_point: ONNXTargetPoint):
145124
:param target_point: The TargetPoint instance for the insertion that contains layer's information.
146125
"""
147126
super().__init__(TransformationType.INSERT, target_point)
148-
149-
def union(self, other: "TransformationCommand") -> "TransformationCommand":
150-
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
151-
raise NotImplementedError()

nncf/onnx/quantization/quantizer_parameters.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616

17+
from nncf.experimental.tensor import functions as fns
1718
from nncf.quantization.fake_quantize import FakeQuantizeParameters
1819
from nncf.quantization.fake_quantize import calculate_scale_zero_point
1920

@@ -54,17 +55,17 @@ def convert_fq_params_to_onnx_params(
5455
if levels not in [255, 256]:
5556
raise ValueError("Can only export to INT8/UIN8 256-level ONNX Quantize/Dequantize pairs.")
5657

57-
input_low, input_high = parameters.input_low.data, parameters.input_high.data
58-
output_low, output_high = parameters.output_low.data, parameters.output_high.data
59-
if not np.allclose(input_high, output_high) or not np.allclose(input_low, output_low):
58+
input_low, input_high = parameters.input_low, parameters.input_high
59+
output_low, output_high = parameters.output_low, parameters.output_high
60+
if not fns.allclose(input_high, output_high) or not fns.allclose(input_low, output_low):
6061
raise ValueError(
6162
"ONNX Quantize/Dequantize pairs only support input_high == output_high and input_low == output_low."
6263
)
6364

6465
level_low, level_high = get_level_low_level_high(tensor_type)
6566
narrow_range = levels == 2**num_bits - 1
6667
scale, zero_point = calculate_scale_zero_point(input_low, input_high, level_low, level_high, narrow_range)
67-
return ONNXQuantizerLayerParameters(scale, zero_point, tensor_type, axis)
68+
return ONNXQuantizerLayerParameters(scale.data, zero_point.data, tensor_type, axis)
6869

6970

7071
def get_level_low_level_high(tensor_type: np.dtype) -> Tuple[int, int]:

0 commit comments

Comments
 (0)