@@ -61,7 +61,7 @@ def get_system_prompt(model_language, system_prompt=None):
61
61
return (
62
62
DEFAULT_SYSTEM_PROMPT_CHINESE
63
63
if (model_language == "Chinese" )
64
- else DEFAULT_SYSTEM_PROMPT_JAPANESE if (model_language == "Japanese" ) else DEFAULT_SYSTEM_PROMPT
64
+ else ( DEFAULT_SYSTEM_PROMPT_JAPANESE if (model_language == "Japanese" ) else DEFAULT_SYSTEM_PROMPT )
65
65
)
66
66
67
67
@@ -88,6 +88,7 @@ def __init__(self, tokenizer):
88
88
self .tokens_cache = []
89
89
self .text_queue = queue .Queue ()
90
90
self .print_len = 0
91
+ self .decoded_lengths = []
91
92
92
93
def __iter__ (self ):
93
94
"""
@@ -140,27 +141,32 @@ def put(self, token_id: int) -> bool:
140
141
"""
141
142
self .tokens_cache .append (token_id )
142
143
text = self .tokenizer .decode (self .tokens_cache )
144
+ self .decoded_lengths .append (len (text ))
143
145
144
146
word = ""
147
+ delay_n_tokens = 3
145
148
if len (text ) > self .print_len and "\n " == text [- 1 ]:
146
149
# Flush the cache after the new line symbol.
147
150
word = text [self .print_len :]
148
151
self .tokens_cache = []
152
+ self .decoded_lengths = []
149
153
self .print_len = 0
150
- elif len (text ) >= 3 and text [- 3 : ] == chr (65533 ):
154
+ elif len (text ) > 0 and text [- 1 ] == chr (65533 ):
151
155
# Don't print incomplete text.
152
- pass
153
- elif len (text ) > self .print_len :
154
- # It is possible to have a shorter text after adding new token.
155
- # Print to output only if text length is increaesed.
156
- word = text [self .print_len :]
157
- self .print_len = len (text )
156
+ self .decoded_lengths [- 1 ] = - 1
157
+ elif len (self .tokens_cache ) >= delay_n_tokens :
158
+ print_until = self .decoded_lengths [- delay_n_tokens ]
159
+ if print_until != - 1 and print_until > self .print_len :
160
+ # It is possible to have a shorter text after adding new token.
161
+ # Print to output only if text length is increased and text is complete (print_until != -1).
162
+ word = text [self .print_len : print_until ]
163
+ self .print_len = print_until
158
164
self .put_word (word )
159
165
160
166
if self .get_stop_flag ():
161
167
# When generation is stopped from streamer then end is not called, need to call it here manually.
162
168
self .end ()
163
- return True # True means stop generation
169
+ return True # True means stop generation
164
170
else :
165
171
return False # False means continue generation
166
172
@@ -176,23 +182,18 @@ def end(self):
176
182
self .print_len = 0
177
183
self .put_word (None )
178
184
179
- def reset (self ):
180
- self .tokens_cache = []
181
- self .text_queue = queue .Queue ()
182
- self .print_len = 0
183
-
184
185
185
186
class ChunkStreamer(IterableStreamer):
    """Streamer that hands tokens to the parent streamer in fixed-size chunks.

    Incoming tokens are buffered in ``self.tokens_cache``; only every
    ``tokens_len``-th token is forwarded to the parent ``put``.

    NOTE(review): buffered tokens are recorded in ``self.decoded_lengths``
    with the placeholder ``-1``. If the parent ``put`` picks its print
    boundary from ``decoded_lengths[-delay_n_tokens]`` and skips ``-1``
    entries, chunk sizes of 3 or more may only ever emit text on a newline
    flush or at ``end()`` -- confirm against the parent implementation.
    """

    def __init__(self, tokenizer, tokens_len):
        """Store the chunk size after initialising the parent streamer.

        Args:
            tokenizer: Passed unchanged to the parent constructor.
            tokens_len: Number of tokens that make up one chunk. Must be
                non-zero, since it is used as a modulus in ``put``.
        """
        super().__init__(tokenizer)
        self.tokens_len = tokens_len

    def put(self, token_id: int) -> bool:
        """Buffer ``token_id``; delegate to the parent when a chunk is full.

        Args:
            token_id: The newly generated token id.

        Returns:
            ``False`` while the token is only buffered; otherwise whatever
            the parent ``put`` returns for this token.
        """
        # The chunk is complete when this token brings the cache size to a
        # multiple of tokens_len.
        chunk_is_full = (len(self.tokens_cache) + 1) % self.tokens_len == 0
        if chunk_is_full:
            # The parent appends the token and decodes the cache itself.
            return super().put(token_id)

        # Not a chunk boundary yet: remember the token without decoding it
        # and keep decoded_lengths aligned with tokens_cache via a placeholder.
        self.tokens_cache.append(token_id)
        self.decoded_lengths.append(-1)
        return False
197
198
198
199
@@ -368,7 +369,11 @@ def stop_chat_and_clear_history(streamer):
368
369
interactive = True ,
369
370
info = "Penalize repetition — 1.0 to disable." ,
370
371
)
371
- gr .Examples (examples , inputs = msg , label = "Click on any example and press the 'Submit' button" )
372
+ gr .Examples (
373
+ examples ,
374
+ inputs = msg ,
375
+ label = "Click on any example and press the 'Submit' button" ,
376
+ )
372
377
373
378
msg .submit (
374
379
fn = bot ,
@@ -383,6 +388,11 @@ def stop_chat_and_clear_history(streamer):
383
388
queue = True ,
384
389
)
385
390
stop .click (fn = stop_chat , inputs = streamer , outputs = [streamer ], queue = False )
386
- clear .click (fn = stop_chat_and_clear_history , inputs = streamer , outputs = [chatbot , streamer ], queue = False )
391
+ clear .click (
392
+ fn = stop_chat_and_clear_history ,
393
+ inputs = streamer ,
394
+ outputs = [chatbot , streamer ],
395
+ queue = False ,
396
+ )
387
397
388
398
return demo
0 commit comments