Strip for LoRA modules #3331

Open · wants to merge 5 commits into base: develop
1 change: 1 addition & 0 deletions .ci/cspell_dict.txt
@@ -220,6 +220,7 @@ logit
loglikelihoods
lstmsequence
lstsq
lspec
lyalyushkin
mapillary
maskrcnn
159 changes: 157 additions & 2 deletions nncf/torch/quantization/strip.py
@@ -10,16 +10,35 @@
# limitations under the License.


from typing import List

import numpy as np
import torch
from torch.quantization.fake_quantize import FakeQuantize

import nncf
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import SymmetricLoraQuantizer
from nncf.torch.quantization.layers import SymmetricQuantizer
from nncf.torch.quantization.quantize_functions import TuneRange

SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8]

@@ -171,6 +190,142 @@ def strip_quantized_model(model: NNCFNetwork):
:param model: Compressed model.
:return: The modified NNCF network.
"""
model = replace_quantizer_to_torch_native_module(model)
model = remove_disabled_quantizers(model)
model_layout = model.nncf.transformation_layout()
transformations = model_layout.transformations
if any([type(q.fn) in [AsymmetricLoraQuantizer, SymmetricLoraQuantizer] for q in transformations]):
model = replace_with_decompressors(model, transformations)
else:
model = replace_quantizer_to_torch_native_module(model)
model = remove_disabled_quantizers(model)
return model


def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork:
"""
Returns the model with Quantizers replaced with Decompressors.
For the Quantizers containing LoRA adapters, it is important to quantize and
dequantize weights in the same manner and then quantize them using a Decompressor-friendly formula.
This approach allows us to prevent floating-point errors that can occur due to the different order of operations.
Contributor comment on lines +205 to +209:
Please be more specific, many important details are missing.

Suggested change
Returns the model with Quantizers replaced with Decompressors.
For the Quantizers containing LoRA adapters, it is important to quantize and
dequantize weights in the same manner and then quantize them using a Decompressor-friendly formula.
This approach allows us to prevent floating-point errors that can occur due to the different order of operations.
Performs transformation from fake quantize format (FQ) to dequantization one (DQ). The former takes floating-point input, quantizes and dequantizes, and returns a floating-point value, while the latter takes a quantized integer representation, dequantizes it, and outputs a floating-point result. Mathematically, both methods lead to the same outcome, but due to differences in the order of operations and rounding errors, the actual results may differ. In particular, this error can occur for values that are located in the midpoint between two quantized values ("quants"). The FQ format may round these values to one "quant", while the DQ format rounds them to another "quant". To avoid these issues, the compressed representation should be provided not by directly quantizing the input, but by quantizing a pre-processed, fake-quantized, floating-point representation.

:param model: Compressed model with LoRA quantizers to be replaced by decompressors.
:return: The modified NNCF network.
"""
transformation_layout = TransformationLayout()
model = model.nncf.get_clean_shallow_copy()
graph = model.nncf.get_graph()

for command in transformations:
quantizer = command.fn

if len(command.target_points) > 1:
msg = "Command contains more than one target point!"
raise nncf.ValidationError(msg)

tp = command.target_points[0]
node_with_weight = graph.get_node_by_name(tp.target_node_name)
weight_node = get_const_node(node_with_weight, tp.input_port_id, graph)

module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
module = get_module_by_name(module_name, model)
original_weight = getattr(module, weight_attr_name)

original_dtype = original_weight.dtype
original_shape = original_weight.shape
original_eps = torch.finfo(original_dtype).eps

qdq_weight = quantizer.quantize(original_weight)
if hasattr(quantizer, "_lspec"):
# Special reshape for LoRA-grouped output
qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape)
qdq_weight = qdq_weight.to(original_dtype)

if isinstance(quantizer, AsymmetricQuantizer):
input_range_safe = abs(quantizer.input_range) + quantizer.eps
input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels)

integer_dtype = torch.uint8

input_low = input_low.to(original_dtype)
input_range = input_range.to(original_dtype)

scale = input_range / quantizer.level_high
scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
scale = scale.to(original_dtype)

zero_point = quantizer.level_low - torch.round(input_low / scale)
zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high)
zero_point = zero_point.to(integer_dtype)

q_weight = qdq_weight / scale
q_weight = q_weight + zero_point
q_weight = torch.round(q_weight)
q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
q_weight = q_weight.to(integer_dtype)

if quantizer.num_bits == 8:
decompressor = INT8AsymmetricWeightsDecompressor(
scale=scale, zero_point=zero_point, result_dtype=original_dtype
)
else:
decompressor = INT4AsymmetricWeightsDecompressor(
scale=scale,
zero_point=zero_point,
compressed_weight_shape=q_weight.shape,
result_shape=original_shape,
result_dtype=original_dtype,
)

elif isinstance(quantizer, SymmetricQuantizer):
integer_dtype = torch.int8

scale = quantizer.scale / abs(quantizer.level_low)
scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
scale = scale.to(original_dtype)

q_weight = qdq_weight / scale
q_weight = torch.round(q_weight)
q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
q_weight = q_weight.to(integer_dtype)

if quantizer.num_bits == 8:
decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype)
else:
decompressor = INT4SymmetricWeightsDecompressor(
scale=scale,
compressed_weight_shape=q_weight.shape,
result_shape=original_shape,
result_dtype=original_dtype,
)

packed_tensor = decompressor.pack_weight(q_weight)

# replace the original weight with the packed compressed tensor
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
setattr(module, weight_attr_name, compressed_parameter)

consumer_nodes = graph.get_next_nodes(weight_node)
if len(consumer_nodes) > 1:
for consumer_node in consumer_nodes:
consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name))
for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False):
if id(param) == id(original_weight):
setattr(consumer_module, name, compressed_parameter)

# register the weight decompression module in the model
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"

# insert the weight decompressor into the model as a post-hook on the model weight
target_point = PTTargetPoint(
TargetType.OPERATOR_POST_HOOK,
target_node_name=weight_node.node_name,
)
transformation_layout.register(
PTSharedFnInsertionCommand(
[target_point],
decompressor,
decompressor_name,
)
)

return PTModelTransformer(model).transform(transformation_layout)
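
As a self-contained illustration of the asymmetric branch above, the sketch below (not part of the PR; the levels and range values are made up) mirrors the decompressor-friendly formulas, scale = input_range / level_high and zero_point = level_low - round(input_low / scale). It checks that re-quantizing an already fake-quantized weight and dequantizing it again reproduces the same tensor, which is the property the suggested docstring describes.

import torch

# Hypothetical asymmetric INT4 setup: unsigned levels 0..15 and a made-up range.
level_low, level_high = 0, 15
input_low = torch.tensor(-0.4)      # lower bound of the quantization range
input_range = torch.tensor(1.5)     # width of the quantization range

# Decompressor-friendly parameters, derived as in replace_with_decompressors().
scale = input_range / level_high
zero_point = torch.clip(level_low - torch.round(input_low / scale), level_low, level_high)

# A weight that already lies on the fake-quantization grid.
qdq_weight = (torch.tensor([3.0, 7.0, 12.0]) - zero_point) * scale

# Re-quantize it with the same formula the strip pass uses ...
q_weight = torch.clip(torch.round(qdq_weight / scale + zero_point), level_low, level_high)
q_weight = q_weight.to(torch.uint8)

# ... then dequantize: the fake-quantized weight is reproduced.
dequantized = (q_weight.to(torch.float32) - zero_point) * scale
assert torch.allclose(dequantized, qdq_weight)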
11 changes: 11 additions & 0 deletions tests/torch/helpers.py
@@ -773,3 +773,14 @@ def _check_pre_post_hooks(
assert len(actual_hooks) == len(ref_hooks)
for actual_hook, ref_hook in zip(actual_hooks, ref_hooks):
assert actual_hook is ref_hook


class LinearModel(nn.Module):
def __init__(self, input_shape: List[int]):
super().__init__()
with set_torch_seed():
self.linear = nn.Linear(input_shape[1], input_shape[0], bias=False)
self.linear.weight.data = torch.randn(input_shape) - 0.5

def forward(self, x):
return self.linear(x)
39 changes: 39 additions & 0 deletions tests/torch/quantization/test_strip.py
@@ -34,6 +34,7 @@
from tests.common.quantization.data_generators import generate_sweep_data
from tests.common.quantization.data_generators import get_quant_len_by_range
from tests.torch.helpers import BasicConvTestModel
from tests.torch.helpers import LinearModel
from tests.torch.helpers import create_compressed_model_and_algo_for_test
from tests.torch.helpers import register_bn_adaptation_init_args
from tests.torch.quantization.test_functions import get_test_data
@@ -325,3 +326,41 @@ def test_nncf_strip_api(strip_type, do_copy):

assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize)
assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)


@pytest.mark.parametrize(
("mode", "torch_dtype", "atol"),
(
(nncf.CompressWeightsMode.INT4_ASYM, torch.float32, 0.01),
(nncf.CompressWeightsMode.INT4_ASYM, torch.float16, 0.01),
(nncf.CompressWeightsMode.INT4_ASYM, torch.bfloat16, 0.01),
(nncf.CompressWeightsMode.INT4_SYM, torch.float32, 0.01),
(nncf.CompressWeightsMode.INT4_SYM, torch.float16, 0.01),
(nncf.CompressWeightsMode.INT4_SYM, torch.bfloat16, 0.01),
Contributor comment: Is it possible to reduce some thresholds for sym or float32?

),
)
def test_nncf_strip_lora_model(mode, torch_dtype, atol):
input_shape = [1, 16]
model = LinearModel(input_shape=input_shape)
model = model.to(torch_dtype)
with torch.no_grad():
example = torch.ones(input_shape).to(torch_dtype)
dataset = [example]

compressed_model = nncf.compress_weights(
model,
ratio=1,
group_size=4,
mode=mode,
backup_mode=None,
dataset=nncf.Dataset(dataset),
all_layers=True,
compression_format=nncf.CompressionFormat.FQ_LORA,
)

compressed_output = compressed_model(example)

strip_compressed_model = nncf.strip(compressed_model, do_copy=True)
stripped_output = strip_compressed_model(example)

assert torch.allclose(compressed_output, stripped_output, atol=atol)
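
Beyond the end-to-end check above, the toy snippet below (not part of the test; the scale and levels are made up, and the fake-quantize step is simplified) illustrates why a small atol is enough: the strip pass re-quantizes the already fake-quantized weight, so dequantizing the stored integers reproduces the fake-quantized values up to float rounding.

import torch

# Hypothetical signed INT4 setup with a made-up per-tensor scale.
level_low, level_high = -8, 7
scale = torch.tensor(0.8) / abs(level_low)   # decompressor scale, as in the symmetric branch

raw_weight = torch.randn(16)
# Simplified fake quantization on the same grid the decompressor will use.
qdq_weight = torch.clip(torch.round(raw_weight / scale), level_low, level_high) * scale

# Re-quantize the fake-quantized weight and store it as int8 ...
q_weight = torch.clip(torch.round(qdq_weight / scale), level_low, level_high).to(torch.int8)
# ... then dequantize: the fake-quantized values come back up to float rounding.
assert torch.allclose(q_weight.to(torch.float32) * scale, qdq_weight)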