@@ -2,8 +2,7 @@
 import openvino_genai as ov_genai
 from uuid import uuid4
 from threading import Event, Thread
-import queue
-import sys
+from gena_helper import ChunkStreamer
 
 max_new_tokens = 256
 
@@ -65,137 +64,6 @@ def get_system_prompt(model_language, system_prompt=None):
 )
 
 
-class IterableStreamer(ov_genai.StreamerBase):
-    """
-    A custom streamer class for handling token streaming and detokenization with buffering.
-
-    Attributes:
-        tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
-        tokens_cache (list): A buffer to accumulate tokens for detokenization.
-        text_queue (Queue): A synchronized queue for storing decoded text chunks.
-        print_len (int): The length of the printed text to manage incremental decoding.
-    """
-
-    def __init__(self, tokenizer):
-        """
-        Initializes the IterableStreamer with the given tokenizer.
-
-        Args:
-            tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
-        """
-        super().__init__()
-        self.tokenizer = tokenizer
-        self.tokens_cache = []
-        self.text_queue = queue.Queue()
-        self.print_len = 0
-        self.decoded_lengths = []
-
-    def __iter__(self):
-        """
-        Returns the iterator object itself.
-        """
-        return self
-
-    def __next__(self):
-        """
-        Returns the next value from the text queue.
-
-        Returns:
-            str: The next decoded text chunk.
-
-        Raises:
-            StopIteration: If there are no more elements in the queue.
-        """
-        value = self.text_queue.get()  # get() will be blocked until a token is available.
-        if value is None:
-            raise StopIteration
-        return value
-
-    def get_stop_flag(self):
-        """
-        Checks whether the generation process should be stopped.
-
-        Returns:
-            bool: Always returns False in this implementation.
-        """
-        return False
-
-    def put_word(self, word: str):
-        """
-        Puts a word into the text queue.
-
-        Args:
-            word (str): The word to put into the queue.
-        """
-        self.text_queue.put(word)
-
-    def put(self, token_id: int) -> bool:
-        """
-        Processes a token and manages the decoding buffer. Adds decoded text to the queue.
-
-        Args:
-            token_id (int): The token_id to process.
-
-        Returns:
-            bool: True if generation should be stopped, False otherwise.
-        """
-        self.tokens_cache.append(token_id)
-        text = self.tokenizer.decode(self.tokens_cache)
-        self.decoded_lengths.append(len(text))
-
-        word = ""
-        delay_n_tokens = 3
-        if len(text) > self.print_len and "\n" == text[-1]:
-            # Flush the cache after the new line symbol.
-            word = text[self.print_len :]
-            self.tokens_cache = []
-            self.decoded_lengths = []
-            self.print_len = 0
-        elif len(text) > 0 and text[-1] == chr(65533):
-            # Don't print incomplete text.
-            self.decoded_lengths[-1] = -1
-        elif len(self.tokens_cache) >= delay_n_tokens:
-            print_until = self.decoded_lengths[-delay_n_tokens]
-            if print_until != -1 and print_until > self.print_len:
-                # It is possible to have a shorter text after adding new token.
-                # Print to output only if text length is increased and text is complete (print_until != -1).
-                word = text[self.print_len : print_until]
-                self.print_len = print_until
-        self.put_word(word)
-
-        if self.get_stop_flag():
-            # When generation is stopped from streamer then end is not called, need to call it here manually.
-            self.end()
-            return True  # True means stop generation
-        else:
-            return False  # False means continue generation
-
-    def end(self):
-        """
-        Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
-        """
-        text = self.tokenizer.decode(self.tokens_cache)
-        if len(text) > self.print_len:
-            word = text[self.print_len :]
-            self.put_word(word)
-        self.tokens_cache = []
-        self.print_len = 0
-        self.put_word(None)
-
-
-class ChunkStreamer(IterableStreamer):
-
-    def __init__(self, tokenizer, tokens_len):
-        super().__init__(tokenizer)
-        self.tokens_len = tokens_len
-
-    def put(self, token_id: int) -> bool:
-        if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
-            self.tokens_cache.append(token_id)
-            self.decoded_lengths.append(-1)
-            return False
-        return super().put(token_id)
-
 
 def make_demo(pipe, model_configuration, model_id, model_language, disable_advanced=False):
     import gradio as gr
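
The streamer classes removed above now live in gena_helper and are pulled in via the new import. For reference, here is a minimal sketch of how the imported ChunkStreamer is typically driven, assuming it keeps the removed class's (tokenizer, tokens_len) constructor and blocking-iterator behavior; the stream_answer helper and its parameters are hypothetical, and pipe is assumed to be an ov_genai.LLMPipeline:

from threading import Thread

import openvino_genai as ov_genai

from gena_helper import ChunkStreamer


def stream_answer(pipe: ov_genai.LLMPipeline, prompt: str, tokens_len: int = 3):
    # Hypothetical helper, not part of this commit.
    streamer = ChunkStreamer(pipe.get_tokenizer(), tokens_len)

    config = ov_genai.GenerationConfig()
    config.max_new_tokens = 256  # matches max_new_tokens at the top of this file

    # generate() blocks until completion, so it runs on a worker thread;
    # the streamer's internal queue hands decoded chunks back to this thread.
    worker = Thread(target=pipe.generate, args=(prompt, config, streamer))
    worker.start()

    partial = ""
    for chunk in streamer:  # iteration stops once end() enqueues None
        partial += chunk
        yield partial

    worker.join()

Because __next__ blocks on the queue and end() enqueues a None sentinel, the consumer loop terminates cleanly when generation finishes; keeping generate() on a worker thread is what lets a UI callback yield partial text as it arrives.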