-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FX] Support weight quantization for operations where weight_port_id != 1
#3334
base: develop
Are you sure you want to change the base?
Changes from all commits
b8203a5
474e6b7
7319f9f
3308b40
0ee4f8b
74f677a
3565ba3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -458,17 +458,16 @@ def _get_stat_collector( | |
is_weight = target_point.is_weight_target_point() | ||
node = graph.get_node_by_name(target_point.target_node_name) | ||
shape = self._backend_entity.get_target_point_shape(graph, node, target_point) | ||
range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) | ||
|
||
|
||
channel_axes = () | ||
if qconfig.per_channel: | ||
channel_axes = ( | ||
self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,) | ||
) | ||
if is_weight: | ||
channel_axes = self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) | ||
else: | ||
channel_axes = (1,) | ||
|
||
# Weight statistics is constant, so only one collection is enough. | ||
range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is also unnecessary. |
||
num_samples = self._subset_size if not is_weight else 1 | ||
|
||
batchwise_statistics = batchwise_statistics and not is_weight | ||
|
||
collector_params = RangeInitCollectorParams( | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -40,6 +40,7 @@ | |||||
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand | ||||||
from nncf.torch.hardware.config import PTHWConfig | ||||||
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids | ||||||
from nncf.torch.model_graph_manager import get_weight_channel_axes | ||||||
from nncf.torch.nncf_network import NNCFNetwork | ||||||
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT | ||||||
from nncf.torch.quantization.layers import QUANTIZATION_MODULES | ||||||
|
@@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point | |||||
|
||||||
@staticmethod | ||||||
def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: | ||||||
# TODO(dlyakhov): support transpose conv and other cases | ||||||
return (0,) | ||||||
return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id) | ||||||
|
||||||
@staticmethod | ||||||
def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: | ||||||
|
@@ -177,16 +177,25 @@ def _get_input_scale_shape( | |||||
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool | ||||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: | ||||||
is_weights = target_point.is_weight_target_point() | ||||||
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) | ||||||
|
||||||
if is_weights: | ||||||
# TODO(dlyakhov): support transpose conv/ make channel_idx common | ||||||
channel_idx = 0 | ||||||
node = nncf_graph.get_node_by_name(target_point.target_node_name) | ||||||
channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id) | ||||||
else: | ||||||
channel_idx = 1 # channel dim for activations | ||||||
channel_axes = [1] | ||||||
|
||||||
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) | ||||||
scale_shape = tuple( | ||||||
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) | ||||||
) | ||||||
channel_idx = channel_axes[0] if channel_axes else 0 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Since channel axes is already being checked and handled in the |
||||||
|
||||||
if is_weights and not channel_axes: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
to cover the case of vector weights which are being quantized per channel |
||||||
scale_shape = (1,) | ||||||
else: | ||||||
scale_shape = tuple(get_scale_shape( | ||||||
input_shape, | ||||||
is_weights=is_weights, | ||||||
per_channel=per_channel, | ||||||
channel_idx=channel_idx | ||||||
)) | ||||||
|
||||||
return input_shape, scale_shape, channel_idx | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not required; it can be reverted to the old code.