Commit 3bfbab5

StaticLLMPipeline dangling models hotfix (#693)
1 parent: 406393f

2 files changed (+13 −9)

src/cpp/src/llm_pipeline_static.cpp

+9 −9

@@ -144,26 +144,26 @@ StaticLLMPipeline::StaticLLMPipeline(
     */
     ov::Core core;
     // (1) Read the template model - this will be kvcache model
-    auto kvcache_model = core.read_model(path / "openvino_model.xml");
+    m_kvcache_model = core.read_model(path / "openvino_model.xml");
     // (2) Expose KV-cache input and output layers from kvcache model
-    ov::pass::StatefulToStateless().run_on_model(kvcache_model);
+    ov::pass::StatefulToStateless().run_on_model(m_kvcache_model);
     // (3) Clone the model - this will be prefill
-    auto prefill_model = kvcache_model->clone();
-    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
+    m_prefill_model = m_kvcache_model->clone();
+    m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
     // (4) Reshape both models to static shape
     m_kvcache_desc = KVCacheDesc { 1024u, 0u };
     const uint32_t max_prompt_size = m_kvcache_desc.total_size;
     const uint32_t max_kvcache_size = m_kvcache_desc.total_size;
-    reshape_to_static(prefill_model, max_prompt_size, max_kvcache_size);
-    reshape_to_static(kvcache_model, 1u, max_kvcache_size);
+    reshape_to_static(m_prefill_model, max_prompt_size, max_kvcache_size);
+    reshape_to_static(m_kvcache_model, 1u, max_kvcache_size);
     // (5) Add slices to kvcache model
-    kvcache_model = add_slices_to_kvcache_inputs(kvcache_model);
+    m_kvcache_model = add_slices_to_kvcache_inputs(m_kvcache_model);
     // (6) Compile both model
     m_prefill_request = core.compile_model(
-        prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG")
+        m_prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG")
     ).create_infer_request();
     m_kvcache_request = core.compile_model(
-        kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG")
+        m_kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG")
     ).create_infer_request();
     // (7) Initialize tensors
     prepare_for_new_conversation();

src/cpp/src/llm_pipeline_static.hpp

+4 −0

@@ -46,6 +46,10 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
         uint32_t num_stored_tokens;
     };

+    // FIXME: Ideally, we don't need to keep those
+    std::shared_ptr<ov::Model> m_kvcache_model;
+    std::shared_ptr<ov::Model> m_prefill_model;
+
     KVCacheDesc m_kvcache_desc;
     ov::InferRequest m_kvcache_request;
     ov::InferRequest m_prefill_request;
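For context, a minimal sketch of the ownership change this hotfix makes, assuming only what the diff shows: the ov::Model objects move from constructor locals to class members, so they stay alive for the pipeline's lifetime instead of being released when the constructor returns. The class and member names below are illustrative, not taken from the repository.

#include <memory>
#include <string>

#include "openvino/openvino.hpp"

// Hypothetical, simplified pipeline showing the same ownership change:
// the ov::Model is stored as a member rather than a constructor-local
// variable, so it is not destroyed when the constructor returns.
class ExamplePipeline {
public:
    ExamplePipeline(const std::string& model_path, const std::string& device) {
        ov::Core core;
        // Before the fix the model was a local:
        //   auto model = core.read_model(model_path);
        // and its shared_ptr went out of scope at the end of the constructor.
        m_model = core.read_model(model_path);
        m_request = core.compile_model(m_model, device).create_infer_request();
    }

private:
    std::shared_ptr<ov::Model> m_model;  // kept for the pipeline's lifetime
    ov::InferRequest m_request;
};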
