Strip for LoRA modules #3331
base: develop
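Summary (inferred from the diff below, since the PR description itself is minimal): nncf.strip() now detects LoRA quantizers (AsymmetricLoraQuantizer / SymmetricLoraQuantizer) among the model's transformation commands and, instead of converting them to torch FakeQuantize modules, re-quantizes the weights into packed INT4/INT8 tensors and inserts the corresponding weight decompressors as post-hooks. A minimal usage sketch modeled on the new test at the end of this PR follows; the model class and shapes here are illustrative assumptions, not part of the change.

import torch

import nncf


class TinyLinear(torch.nn.Module):  # hypothetical stand-in for tests.torch.helpers.LinearModel
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 32)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear()
example = torch.ones([1, 16])

# Weight compression in the FQ_LORA format, as in the new test below.
compressed = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_ASYM,
    ratio=1,
    group_size=4,
    all_layers=True,
    dataset=nncf.Dataset([example]),
    compression_format=nncf.CompressionFormat.FQ_LORA,
)

# strip() folds the LoRA fake-quantizers into packed integer weights plus
# decompressor post-hooks; the stripped model stays close to the compressed one.
stripped = nncf.strip(compressed, do_copy=True)
assert torch.allclose(compressed(example), stripped(example), atol=0.01)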
@@ -220,6 +220,7 @@ logit
loglikelihoods
lstmsequence
lstsq
lspec
lyalyushkin
mapillary
maskrcnn
@@ -10,16 +10,35 @@
# limitations under the License.

from typing import List

import numpy as np
import torch
from torch.quantization.fake_quantize import FakeQuantize

import nncf
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import SymmetricLoraQuantizer
from nncf.torch.quantization.layers import SymmetricQuantizer
from nncf.torch.quantization.quantize_functions import TuneRange

SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8]
@@ -171,6 +190,142 @@ def strip_quantized_model(model: NNCFNetwork):
     :param model: Compressed model.
     :return: The modified NNCF network.
     """
-    model = replace_quantizer_to_torch_native_module(model)
-    model = remove_disabled_quantizers(model)
+    model_layout = model.nncf.transformation_layout()
+    transformations = model_layout.transformations
+    if any([type(q.fn) in [AsymmetricLoraQuantizer, SymmetricLoraQuantizer] for q in transformations]):
+        model = replace_with_decompressors(model, transformations)
+    else:
+        model = replace_quantizer_to_torch_native_module(model)
+        model = remove_disabled_quantizers(model)
     return model

def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork:
    """
    Returns the model with Quantizers replaced with Decompressors.
    For the Quantizers containing LoRA adapters, it is important to quantize and
    dequantize weights in the same manner and then quantize them using a Decompressor-friendly formula.
    This approach prevents floating-point errors that could otherwise occur due to a different order of operations.

    :param model: Compressed model.
    :param transformations: Transformation commands holding the quantizers to replace.
    :return: The modified NNCF network.
    """
    transformation_layout = TransformationLayout()
    model = model.nncf.get_clean_shallow_copy()
    graph = model.nncf.get_graph()

    for command in transformations:
        quantizer = command.fn

        if len(command.target_points) > 1:
            msg = "Command contains more than one target point!"
            raise nncf.ValidationError(msg)

        tp = command.target_points[0]
        node_with_weight = graph.get_node_by_name(tp.target_node_name)
        weight_node = get_const_node(node_with_weight, tp.input_port_id, graph)

        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
        module = get_module_by_name(module_name, model)
        original_weight = getattr(module, weight_attr_name)

        original_dtype = original_weight.dtype
        original_shape = original_weight.shape
        original_eps = torch.finfo(original_dtype).eps

        qdq_weight = quantizer.quantize(original_weight)
        if hasattr(quantizer, "_lspec"):
            # Special reshape for the LoRA-grouped output
            qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape)
        qdq_weight = qdq_weight.to(original_dtype)

        if isinstance(quantizer, AsymmetricQuantizer):
            input_range_safe = abs(quantizer.input_range) + quantizer.eps
            input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels)

            integer_dtype = torch.uint8

            input_low = input_low.to(original_dtype)
            input_range = input_range.to(original_dtype)

            scale = input_range / quantizer.level_high
            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
            scale = scale.to(original_dtype)

            zero_point = quantizer.level_low - torch.round(input_low / scale)
            zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high)
            zero_point = zero_point.to(integer_dtype)

            q_weight = qdq_weight / scale
            q_weight = q_weight + zero_point
            q_weight = torch.round(q_weight)
            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
            q_weight = q_weight.to(integer_dtype)

            if quantizer.num_bits == 8:
                decompressor = INT8AsymmetricWeightsDecompressor(
                    scale=scale, zero_point=zero_point, result_dtype=original_dtype
                )
            else:
                decompressor = INT4AsymmetricWeightsDecompressor(
                    scale=scale,
                    zero_point=zero_point,
                    compressed_weight_shape=q_weight.shape,
                    result_shape=original_shape,
                    result_dtype=original_dtype,
                )

        elif isinstance(quantizer, SymmetricQuantizer):
            integer_dtype = torch.int8

            scale = quantizer.scale / abs(quantizer.level_low)
            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
            scale = scale.to(original_dtype)

            q_weight = qdq_weight / scale
            q_weight = torch.round(q_weight)
            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
            q_weight = q_weight.to(integer_dtype)

            if quantizer.num_bits == 8:
                decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype)
            else:
                decompressor = INT4SymmetricWeightsDecompressor(
                    scale=scale,
                    compressed_weight_shape=q_weight.shape,
                    result_shape=original_shape,
                    result_dtype=original_dtype,
                )

        packed_tensor = decompressor.pack_weight(q_weight)

        # Set the packed tensor as the module weight
        compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
        setattr(module, weight_attr_name, compressed_parameter)

        consumer_nodes = graph.get_next_nodes(weight_node)
        if len(consumer_nodes) > 1:
            for consumer_node in consumer_nodes:
                consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name))
                for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False):
                    if id(param) == id(original_weight):
                        setattr(consumer_module, name, compressed_parameter)

        # Register the weight decompression module in the model
        decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"

        # Insert the weight decompressor into the model as a post-hook on the model weight
        target_point = PTTargetPoint(
            TargetType.OPERATOR_POST_HOOK,
            target_node_name=weight_node.node_name,
        )
        transformation_layout.register(
            PTSharedFnInsertionCommand(
                [target_point],
                decompressor,
                decompressor_name,
            )
        )

    return PTModelTransformer(model).transform(transformation_layout)

Review comment (on the PTSharedFnInsertionCommand call): Please align it with the latest changes (#3293) in torch_backend.
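To make the docstring's point about operation ordering concrete, here is a self-contained toy illustration in plain PyTorch (not NNCF code) of the decompressor-friendly re-quantization for the asymmetric case. The grid parameters are hand-picked so that input_low lies exactly on the integer grid; the real code additionally adjusts the range via TuneRange.apply and guards against near-zero scales.

import torch

level_low, level_high = 0, 15          # unsigned 4-bit grid
scale = torch.tensor(0.1)              # illustrative scale
input_low = torch.tensor(-0.7)         # chosen as an exact multiple of scale
weight = torch.tensor([-0.83, -0.32, 0.07, 0.51, 0.74, 0.93])

# 1) Fake-quantize (quantize-dequantize) the weight, as the LoRA quantizer does at inference.
qdq_weight = torch.clamp(torch.round((weight - input_low) / scale), level_low, level_high) * scale + input_low

# 2) Re-quantize the already fake-quantized weight with the decompressor formula
#    (the same steps as the AsymmetricQuantizer branch above).
zero_point = level_low - torch.round(input_low / scale)
q_weight = torch.clamp(torch.round(qdq_weight / scale + zero_point), level_low, level_high)

# 3) The decompressor restores (q - zero_point) * scale, which reproduces the
#    fake-quantized weight, because step 2 started from qdq_weight rather than
#    from the raw weight, so no extra rounding disagreement is introduced.
restored = (q_weight - zero_point) * scale
assert torch.allclose(restored, qdq_weight, atol=1e-6)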
@@ -34,6 +34,7 @@
from tests.common.quantization.data_generators import generate_sweep_data
from tests.common.quantization.data_generators import get_quant_len_by_range
from tests.torch.helpers import BasicConvTestModel
from tests.torch.helpers import LinearModel
from tests.torch.helpers import create_compressed_model_and_algo_for_test
from tests.torch.helpers import register_bn_adaptation_init_args
from tests.torch.quantization.test_functions import get_test_data
@@ -325,3 +326,41 @@ def test_nncf_strip_api(strip_type, do_copy):

    assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize)
    assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)


@pytest.mark.parametrize(
    ("mode", "torch_dtype", "atol"),
    (
        (nncf.CompressWeightsMode.INT4_ASYM, torch.float32, 0.01),
        (nncf.CompressWeightsMode.INT4_ASYM, torch.float16, 0.01),
        (nncf.CompressWeightsMode.INT4_ASYM, torch.bfloat16, 0.01),
        (nncf.CompressWeightsMode.INT4_SYM, torch.float32, 0.01),
        (nncf.CompressWeightsMode.INT4_SYM, torch.float16, 0.01),
        (nncf.CompressWeightsMode.INT4_SYM, torch.bfloat16, 0.01),
    ),
)
def test_nncf_strip_lora_model(mode, torch_dtype, atol):
    input_shape = [1, 16]
    model = LinearModel(input_shape=input_shape)
    model = model.to(torch_dtype)
    with torch.no_grad():
        example = torch.ones(input_shape).to(torch_dtype)
        dataset = [example]

        compressed_model = nncf.compress_weights(
            model,
            ratio=1,
            group_size=4,
            mode=mode,
            backup_mode=None,
            dataset=nncf.Dataset(dataset),
            all_layers=True,
            compression_format=nncf.CompressionFormat.FQ_LORA,
        )

        compressed_output = compressed_model(example)

        strip_compressed_model = nncf.strip(compressed_model, do_copy=True)
        stripped_output = strip_compressed_model(example)

        assert torch.allclose(compressed_output, stripped_output, atol=atol)

Review comment (on the parametrized atol values): Is it possible to reduce some thresholds for …
Review comment: Please be more specific, many important details are missing.