[FX] Support weight quantization for operations where weight_port_id != 1
#3334
base: develop
if is_weight:
    channel_axes = self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape))
else:
    channel_axes = (1,)
This is not required; it can be reverted to the old code.
# Weight statistics is constant, so only one collection is enough.
range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig)
This is also unnecessary
)
channel_idx = channel_axes[0] if channel_axes else 0

if is_weights and not channel_axes:
Suggested change:
- if is_weights and not channel_axes:
+ if not len(channel_axes):
This covers the case of vector weights that are quantized per channel.
scale_shape = tuple(
    get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)
channel_idx = channel_axes[0] if channel_axes else 0
channel_idx = channel_axes[0] if channel_axes else 0
Since channel_axes is already checked and handled in the if-else block below, channel_axes[0] can be passed directly as the channel_idx parameter of get_scale_shape.
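To make the empty-channel_axes discussion concrete, here is a minimal sketch of how a scale shape could be derived from a set of channel axes. The helper name sketch_scale_shape and the per-tensor fallback of (1,) are assumptions made for illustration; this is not NNCF's get_scale_shape and does not claim to match its behavior.

```python
from typing import Sequence, Tuple


def sketch_scale_shape(input_shape: Sequence[int], channel_axes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: keep the size of every channel axis, collapse the rest to 1."""
    if not len(channel_axes):
        # Assumption: with no channel axis, fall back to a single per-tensor scale.
        return (1,)
    ndim = len(input_shape)
    # Normalize negative axes so that -1 refers to the last dimension.
    normalized = {axis % ndim for axis in channel_axes}
    return tuple(dim if i in normalized else 1 for i, dim in enumerate(input_shape))


if __name__ == "__main__":
    print(sketch_scale_shape((16, 3, 3, 3), (0,)))  # convolution-like weight, channel axis 0
    print(sketch_scale_shape((3, 16), (1,)))        # weight whose channel axis is the last one
    print(sketch_scale_shape((8,), ()))             # no channel axes supplied
```

In this sketch the emptiness check happens once, before any indexing, so the caller never needs a separate "channel_axes[0] if channel_axes else 0" guard.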
Changes
Updated the FX backend’s _get_input_scale_shape to use the FX insertion point shape and, when available, the actual weight tensor’s shape to compute the per-channel scale shape.
Adjusted the statistics collector in _get_stat_collector so that the reduction and aggregation axes are derived from the same channel axes used for the scale shape computation.
Related tickets
Issue #3206
Tests
All tests run successfully
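As an illustration of the second change described above (deriving the reduction axes from the same channel axes that define the scale shape), the following is a rough sketch. The function names are hypothetical and the logic is an assumption about the general technique, not the code in this PR.

```python
from typing import Sequence, Tuple


def sketch_reduction_axes(ndim: int, channel_axes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: reduce over every axis that is not a channel axis."""
    kept = {axis % ndim for axis in channel_axes}
    return tuple(axis for axis in range(ndim) if axis not in kept)


def sketch_kept_axes(scale_shape: Sequence[int]) -> Tuple[int, ...]:
    """Axes of a scale shape that are not collapsed to size 1."""
    return tuple(i for i, dim in enumerate(scale_shape) if dim != 1)


if __name__ == "__main__":
    weight_shape = (3, 16)   # assumed example: channel axis is 1, not 0
    channel_axes = (1,)
    reduction_axes = sketch_reduction_axes(len(weight_shape), channel_axes)
    scale_shape = tuple(d if i in channel_axes else 1 for i, d in enumerate(weight_shape))
    # The reduced axes and the axes kept by the scale shape partition all axes.
    assert set(reduction_axes) | set(sketch_kept_axes(scale_shape)) == set(range(len(weight_shape)))
    print(reduction_axes, scale_shape)
```

Because both values come from one channel_axes tuple, statistics reduced over reduction_axes have one entry per channel and line up with the scale shape.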