Skip to content

Commit 252c2e3

Browse files
authored
Fix IterableStreamer in llm-chatbot with genai (#2762)
Adopt streamer generation fix from openvino.genai: openvinotoolkit/openvino.genai#1540
1 parent e0a2cea commit 252c2e3

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

notebooks/llm-chatbot/gradio_helper_genai.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_system_prompt(model_language, system_prompt=None):
6161
return (
6262
DEFAULT_SYSTEM_PROMPT_CHINESE
6363
if (model_language == "Chinese")
64-
else DEFAULT_SYSTEM_PROMPT_JAPANESE if (model_language == "Japanese") else DEFAULT_SYSTEM_PROMPT
64+
else (DEFAULT_SYSTEM_PROMPT_JAPANESE if (model_language == "Japanese") else DEFAULT_SYSTEM_PROMPT)
6565
)
6666

6767

@@ -88,6 +88,7 @@ def __init__(self, tokenizer):
8888
self.tokens_cache = []
8989
self.text_queue = queue.Queue()
9090
self.print_len = 0
91+
self.decoded_lengths = []
9192

9293
def __iter__(self):
9394
"""
@@ -140,27 +141,32 @@ def put(self, token_id: int) -> bool:
140141
"""
141142
self.tokens_cache.append(token_id)
142143
text = self.tokenizer.decode(self.tokens_cache)
144+
self.decoded_lengths.append(len(text))
143145

144146
word = ""
147+
delay_n_tokens = 3
145148
if len(text) > self.print_len and "\n" == text[-1]:
146149
# Flush the cache after the new line symbol.
147150
word = text[self.print_len :]
148151
self.tokens_cache = []
152+
self.decoded_lengths = []
149153
self.print_len = 0
150-
elif len(text) >= 3 and text[-3:] == chr(65533):
154+
elif len(text) > 0 and text[-1] == chr(65533):
151155
# Don't print incomplete text.
152-
pass
153-
elif len(text) > self.print_len:
154-
# It is possible to have a shorter text after adding new token.
155-
# Print to output only if text length is increaesed.
156-
word = text[self.print_len :]
157-
self.print_len = len(text)
156+
self.decoded_lengths[-1] = -1
157+
elif len(self.tokens_cache) >= delay_n_tokens:
158+
print_until = self.decoded_lengths[-delay_n_tokens]
159+
if print_until != -1 and print_until > self.print_len:
160+
# It is possible to have a shorter text after adding new token.
161+
# Print to output only if text length is increased and text is complete (print_until != -1).
162+
word = text[self.print_len : print_until]
163+
self.print_len = print_until
158164
self.put_word(word)
159165

160166
if self.get_stop_flag():
161167
# When generation is stopped from streamer then end is not called, need to call it here manually.
162168
self.end()
163-
return True # True means stop generation
169+
return True # True means stop generation
164170
else:
165171
return False # False means continue generation
166172

@@ -176,23 +182,18 @@ def end(self):
176182
self.print_len = 0
177183
self.put_word(None)
178184

179-
def reset(self):
180-
self.tokens_cache = []
181-
self.text_queue = queue.Queue()
182-
self.print_len = 0
183-
184185

185186
class ChunkStreamer(IterableStreamer):
186187

187-
def __init__(self, tokenizer, tokens_len=4):
188+
def __init__(self, tokenizer, tokens_len):
188189
super().__init__(tokenizer)
189190
self.tokens_len = tokens_len
190191

191192
def put(self, token_id: int) -> bool:
192193
if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
193194
self.tokens_cache.append(token_id)
195+
self.decoded_lengths.append(-1)
194196
return False
195-
sys.stdout.flush()
196197
return super().put(token_id)
197198

198199

@@ -368,7 +369,11 @@ def stop_chat_and_clear_history(streamer):
368369
interactive=True,
369370
info="Penalize repetition — 1.0 to disable.",
370371
)
371-
gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
372+
gr.Examples(
373+
examples,
374+
inputs=msg,
375+
label="Click on any example and press the 'Submit' button",
376+
)
372377

373378
msg.submit(
374379
fn=bot,
@@ -383,6 +388,11 @@ def stop_chat_and_clear_history(streamer):
383388
queue=True,
384389
)
385390
stop.click(fn=stop_chat, inputs=streamer, outputs=[streamer], queue=False)
386-
clear.click(fn=stop_chat_and_clear_history, inputs=streamer, outputs=[chatbot, streamer], queue=False)
391+
clear.click(
392+
fn=stop_chat_and_clear_history,
393+
inputs=streamer,
394+
outputs=[chatbot, streamer],
395+
queue=False,
396+
)
387397

388398
return demo

0 commit comments

Comments (0)