Skip to content

Commit 2a57cc8

Browse files
Import InferRequestWrapper from optimum-intel instead of re-defining it (#1660)
1 parent 7027c38 commit 2a57cc8

File tree

1 file changed

+12
-48
lines changed

1 file changed

+12
-48
lines changed

notebooks/267-distil-whisper-asr/267-distil-whisper-asr.ipynb

+12-48
Original file line number | Diff line number | Diff line change
@@ -924,7 +924,7 @@
924924
"### Prepare calibration datasets\n",
925925
"[back to top ⬆️](#Table-of-contents:)\n",
926926
"\n",
927-
"First step is to prepare calibration datasets for quantization. Since we quantize whisper encoder and decoder separately, we need to prepare a calibration dataset for each of the models. We define a `InferRequestWrapper` class that will intercept model inputs and collect them to a list. Then we run model inference on some small amount of audio samples. Generally, increasing the calibration dataset size improves quantization quality."
927+
"First step is to prepare calibration datasets for quantization. Since we quantize whisper encoder and decoder separately, we need to prepare a calibration dataset for each of the models. We import an `InferRequestWrapper` class that will intercept model inputs and collect them to a list. Then we run model inference on some small amount of audio samples. Generally, increasing the calibration dataset size improves quantization quality."
928928
]
929929
},
930930
{
@@ -946,44 +946,10 @@
946946
"%%skip not $to_quantize.value\n",
947947
"\n",
948948
"from itertools import islice\n",
949-
"from typing import List, Any\n",
950-
"from openvino import Tensor\n",
949+
"from optimum.intel.openvino.quantization import InferRequestWrapper\n",
951950
"\n",
952951
"\n",
953-
"class InferRequestWrapper:\n",
954-
" def __init__(self, request, data_cache: List):\n",
955-
" self.request = request\n",
956-
" self.data_cache = data_cache\n",
957-
"\n",
958-
" def __call__(self, *args, **kwargs):\n",
959-
" self.data_cache.append(*args)\n",
960-
" return self.request(*args, **kwargs)\n",
961-
"\n",
962-
" def infer(self, inputs: Any = None, shared_memory: bool = False):\n",
963-
" self.data_cache.append(inputs)\n",
964-
" return self.request.infer(inputs, shared_memory)\n",
965-
"\n",
966-
" def start_async(\n",
967-
" self,\n",
968-
" inputs: Any = None,\n",
969-
" userdata: Any = None,\n",
970-
" share_inputs: bool = False,\n",
971-
" ):\n",
972-
" self.data_cache.append(inputs)\n",
973-
" self.request.infer(inputs, share_inputs)\n",
974-
"\n",
975-
" def wait(self):\n",
976-
" pass\n",
977-
"\n",
978-
" def get_tensor(self, name: str):\n",
979-
" return Tensor(self.request.results[name])\n",
980-
"\n",
981-
" def __getattr__(self, attr):\n",
982-
" if attr in self.__dict__:\n",
983-
" return getattr(self, attr)\n",
984-
" return getattr(self.request, attr)\n",
985-
"\n",
986-
"def collect_calibration_dataset(ov_model, calibration_dataset_size):\n",
952+
"def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):\n",
987953
" # Overwrite model request properties, saving the original ones for restoring later\n",
988954
" original_encoder_request = ov_model.encoder.request\n",
989955
" original_decoder_with_past_request = ov_model.decoder_with_past.request\n",
@@ -1124,25 +1090,24 @@
11241090
"import nncf\n",
11251091
"\n",
11261092
"CALIBRATION_DATASET_SIZE = 50\n",
1127-
"quantized_distil_model_path = Path(f\"{model_path}_quantized\")\n",
1093+
"quantized_model_path = Path(f\"{model_path}_quantized\")\n",
11281094
"\n",
11291095
"\n",
1130-
"def quantize(ov_model, calibration_dataset_size):\n",
1131-
" if not quantized_distil_model_path.exists():\n",
1096+
"def quantize(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):\n",
1097+
" if not quantized_model_path.exists():\n",
11321098
" encoder_calibration_data, decoder_calibration_data = collect_calibration_dataset(\n",
11331099
" ov_model, calibration_dataset_size\n",
11341100
" )\n",
11351101
" print(\"Quantizing encoder\")\n",
11361102
" quantized_encoder = nncf.quantize(\n",
11371103
" ov_model.encoder.model,\n",
11381104
" nncf.Dataset(encoder_calibration_data),\n",
1139-
" preset=nncf.QuantizationPreset.MIXED,\n",
11401105
" subset_size=len(encoder_calibration_data),\n",
11411106
" model_type=nncf.ModelType.TRANSFORMER,\n",
11421107
" # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search\n",
11431108
" advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.50)\n",
11441109
" )\n",
1145-
" ov.save_model(quantized_encoder, quantized_distil_model_path / \"openvino_encoder_model.xml\")\n",
1110+
" ov.save_model(quantized_encoder, quantized_model_path / \"openvino_encoder_model.xml\")\n",
11461111
" del quantized_encoder\n",
11471112
" del encoder_calibration_data\n",
11481113
" gc.collect()\n",
@@ -1151,23 +1116,22 @@
11511116
" quantized_decoder_with_past = nncf.quantize(\n",
11521117
" ov_model.decoder_with_past.model,\n",
11531118
" nncf.Dataset(decoder_calibration_data),\n",
1154-
" preset=nncf.QuantizationPreset.MIXED,\n",
11551119
" subset_size=len(decoder_calibration_data),\n",
11561120
" model_type=nncf.ModelType.TRANSFORMER,\n",
11571121
" # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search\n",
11581122
" advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)\n",
11591123
" )\n",
1160-
" ov.save_model(quantized_decoder_with_past, quantized_distil_model_path / \"openvino_decoder_with_past_model.xml\")\n",
1124+
" ov.save_model(quantized_decoder_with_past, quantized_model_path / \"openvino_decoder_with_past_model.xml\")\n",
11611125
" del quantized_decoder_with_past\n",
11621126
" del decoder_calibration_data\n",
11631127
" gc.collect()\n",
11641128
"\n",
11651129
" # Copy the config file and the first-step-decoder manually\n",
1166-
" shutil.copy(model_path / \"config.json\", quantized_distil_model_path / \"config.json\")\n",
1167-
" shutil.copy(model_path / \"openvino_decoder_model.xml\", quantized_distil_model_path / \"openvino_decoder_model.xml\")\n",
1168-
" shutil.copy(model_path / \"openvino_decoder_model.bin\", quantized_distil_model_path / \"openvino_decoder_model.bin\")\n",
1130+
" shutil.copy(model_path / \"config.json\", quantized_model_path / \"config.json\")\n",
1131+
" shutil.copy(model_path / \"openvino_decoder_model.xml\", quantized_model_path / \"openvino_decoder_model.xml\")\n",
1132+
" shutil.copy(model_path / \"openvino_decoder_model.bin\", quantized_model_path / \"openvino_decoder_model.bin\")\n",
11691133
"\n",
1170-
" quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_distil_model_path, ov_config=ov_config, compile=False)\n",
1134+
" quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False)\n",
11711135
" quantized_ov_model.to(device.value)\n",
11721136
" quantized_ov_model.compile()\n",
11731137
" return quantized_ov_model\n",

0 commit comments

Comments (0)