Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FX] Support weight quantization for operations where weight_port_id != 1 #3334

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,16 @@ def _get_stat_collector(
is_weight = target_point.is_weight_target_point()
node = graph.get_node_by_name(target_point.target_node_name)
shape = self._backend_entity.get_target_point_shape(graph, node, target_point)
range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig)


channel_axes = ()
if qconfig.per_channel:
channel_axes = (
self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,)
)
if is_weight:
channel_axes = self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape))
else:
channel_axes = (1,)
Comment on lines +464 to +467
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not required, it can be reverted back to the old code.


# Weight statistics is constant, so only one collection is enough.
range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also unnecessary

num_samples = self._subset_size if not is_weight else 1

batchwise_statistics = batchwise_statistics and not is_weight

collector_params = RangeInitCollectorParams(
Expand Down
27 changes: 18 additions & 9 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.model_graph_manager import get_weight_channel_axes
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
Expand Down Expand Up @@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point

@staticmethod
def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]:
# TODO(dlyakhov): support transpose conv and other cases
return (0,)
return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id)

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]:
Expand All @@ -177,16 +177,25 @@ def _get_input_scale_shape(
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
is_weights = target_point.is_weight_target_point()
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)

if is_weights:
# TODO(dlyakhov): support transpose conv/ make channel_idx common
channel_idx = 0
node = nncf_graph.get_node_by_name(target_point.target_node_name)
channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id)
else:
channel_idx = 1 # channel dim for activations
channel_axes = [1]

input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
scale_shape = tuple(
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)
channel_idx = channel_axes[0] if channel_axes else 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
channel_idx = channel_axes[0] if channel_axes else 0

Since channel axes is already being checked and handled in the if-else block below. channel_axes[0] can directly be passed to channel_idx parameter of get_scale_shape


if is_weights and not channel_axes:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_weights and not channel_axes:
if not len(channel_axes):

to cover the case of vector weights which are being quantized per channel

scale_shape = (1,)
else:
scale_shape = tuple(get_scale_shape(
input_shape,
is_weights=is_weights,
per_channel=per_channel,
channel_idx=channel_idx
))

return input_shape, scale_shape, channel_idx

Expand Down
Loading