Skip to content

Commit 41824aa

Browse files
committed
SBERT - enable batch embedding and fix sorting
1 parent 1366631 commit 41824aa

File tree

2 files changed

+107
-26
lines changed

2 files changed

+107
-26
lines changed

orangecontrib/text/tests/test_sbert.py

+47-16
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,42 @@
1+
import base64
2+
import json
13
import unittest
2-
from unittest.mock import patch
3-
from collections.abc import Iterator
4+
import zlib
5+
from unittest.mock import patch, ANY
46
import asyncio
57

68
from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
79
from orangecontrib.text import Corpus
810

911
PATCH_METHOD = 'httpx.AsyncClient.post'
10-
RESPONSE = [
11-
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
12-
for i in range(9)
13-
]
14-
12+
RESPONSES = {
13+
t: [i] * EMB_DIM for i, t in enumerate(Corpus.from_file("deerwester").documents)
14+
}
15+
RESPONSE_NONE = RESPONSES.copy()
16+
RESPONSE_NONE[list(RESPONSE_NONE.keys())[-1]] = None
1517
IDEAL_RESPONSE = [[i] * EMB_DIM for i in range(9)]
1618

1719

1820
class DummyResponse:
19-
2021
def __init__(self, content):
2122
self.content = content
2223

2324

24-
def make_dummy_post(response, sleep=0):
25+
def _decompress_text(instance):
26+
return zlib.decompress(base64.b64decode(instance.encode("utf-8"))).decode("utf-8")
27+
28+
29+
def make_dummy_post(responses, sleep=0):
2530
@staticmethod
2631
async def dummy_post(url, headers, data=None, content=None):
2732
assert data or content
2833
await asyncio.sleep(sleep)
29-
return DummyResponse(
30-
content=next(response) if isinstance(response, Iterator) else response
31-
)
34+
data = json.loads(content.decode("utf-8", "replace"))
35+
data_ = data if isinstance(data, list) else [data]
36+
texts = [_decompress_text(instance) for instance in data_]
37+
responses_ = [responses[t] for t in texts]
38+
r = {"embedding": responses_ if isinstance(data, list) else responses_[0]}
39+
return DummyResponse(content=json.dumps(r).encode("utf-8"))
3240
return dummy_post
3341

3442

@@ -51,25 +59,25 @@ def test_empty_corpus(self, mock):
5159
dict()
5260
)
5361

54-
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
62+
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
5563
def test_success(self):
5664
result = self.sbert(self.corpus.documents)
5765
self.assertEqual(result, IDEAL_RESPONSE)
5866

59-
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
67+
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
6068
def test_none_result(self):
6169
result = self.sbert(self.corpus.documents)
6270
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
6371

64-
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
72+
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
6573
def test_transform(self):
6674
res, skipped = self.sbert.transform(self.corpus)
6775
self.assertIsNone(skipped)
6876
self.assertEqual(len(self.corpus), len(res))
6977
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
7078
self.assertEqual(384, len(res.domain.attributes))
7179

72-
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
80+
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
7381
def test_transform_skipped(self):
7482
res, skipped = self.sbert.transform(self.corpus)
7583
self.assertEqual(len(self.corpus) - 1, len(res))
@@ -80,6 +88,29 @@ def test_transform_skipped(self):
8088
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
8189
self.assertEqual(0, len(skipped.domain.attributes))
8290

91+
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
92+
def test_batches_success(self):
93+
for i in range(1, 11): # try different batch sizes
94+
result = self.sbert.embed_batches(self.corpus.documents, i)
95+
self.assertEqual(result, IDEAL_RESPONSE)
96+
97+
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
98+
def test_batches_none_result(self):
99+
for i in range(1, 11): # try different batch sizes
100+
result = self.sbert.embed_batches(self.corpus.documents, i)
101+
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
102+
103+
@patch("orangecontrib.text.vectorization.sbert._ServerCommunicator.embedd_data")
104+
def test_reordered(self, mock):
105+
"""Test that texts are reordered according to their length"""
106+
self.sbert(self.corpus.documents)
107+
mock.assert_called_with(
108+
tuple(sorted(self.corpus.documents, key=len, reverse=True)), callback=ANY
109+
)
110+
111+
self.sbert([["1", "2"], ["4", "5", "6"], ["0"]])
112+
mock.assert_called_with((["4", "5", "6"], ["1", "2"], ["0"]), callback=ANY)
113+
83114

84115
if __name__ == "__main__":
85116
unittest.main()

orangecontrib/text/vectorization/sbert.py

+60-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import zlib
66
import sys
77
from threading import Thread
8-
from typing import Any, List, Optional, Callable, Tuple
8+
from typing import Any, List, Optional, Callable, Tuple, Union
99

1010
import numpy as np
1111
from Orange.misc.server_embedder import ServerEmbedderCommunicator
@@ -29,18 +29,20 @@ def __init__(self) -> None:
2929
)
3030

3131
def __call__(
32-
self, texts: List[str], callback: Callable = dummy_callback
33-
) -> List[Optional[List[float]]]:
32+
self, texts: List[Union[str, List[str]]], callback: Callable = dummy_callback
33+
) -> List[Union[Optional[List[float]], List[Optional[List[float]]]]]:
3434
"""Computes embeddings for given documents.
3535
3636
Parameters
3737
----------
3838
texts
39-
A list of raw texts.
39+
A list of texts, or a list of text batches (each batch is a list of texts).
4040
4141
Returns
4242
-------
43-
An array of embeddings.
43+
List of embeddings for each document. Each item in the list can be either
44+
list of numbers (embedding) or a None when embedding fails.
45+
When texts is a list of batches, the embeddings are returned in batches as well.
4446
"""
4547
if len(texts) == 0:
4648
return []
@@ -49,7 +51,7 @@ def __call__(
4951
# at the end and thus add extra time to the complete embedding time
5052
sorted_texts = sorted(
5153
enumerate(texts),
52-
key=lambda x: len(x[1][0]) if x[1] is not None else 0,
54+
key=lambda x: len(x[1]) if x[1] is not None else 0,
5355
reverse=True,
5456
)
5557
indices, sorted_texts = zip(*sorted_texts)
@@ -111,6 +113,44 @@ def _transform(
111113

112114
return new_corpus, skipped_corpus
113115

116+
def embed_batches(
    self,
    documents: List[str],
    batch_size: int,
    *,
    callback: Callable = dummy_callback
) -> List[Optional[List[float]]]:
    """
    Embed documents by sending batches of documents to the server instead of
    sending one document per request. Using this method is suggested when
    documents are words or extra short documents. Since they embed fast, the
    bottleneck is sending requests to the server, and for those, it is
    faster to send them in batches. In the case of documents with at least a
    few sentences, the bottleneck is embedding itself. In this case, sending
    them in separate requests can speed up embedding since the embedding
    process can be better distributed between workers.

    Parameters
    ----------
    documents
        List of documents that will be sent to the server
    batch_size
        Number of documents in one batch sent to the server. Must be at
        least 1.
    callback
        Callback for reporting the progress. It is forwarded unchanged to
        ``__call__``.

    Returns
    -------
    List of embeddings for each document, in the order of ``documents``.
    Each item in the list can be either list of numbers (embedding) or a
    None when embedding fails.

    Raises
    ------
    ValueError
        When ``batch_size`` is smaller than 1.
    """
    # A zero step makes range() raise and a negative step yields no batches
    # at all (silently dropping every document), so reject both up front.
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    batches = [
        documents[ndx : ndx + batch_size]
        for ndx in range(0, len(documents), batch_size)
    ]
    # Forward the callback so that callers actually receive progress reports.
    embeddings_batches = self(batches, callback=callback)

    embeddings = []
    for batch, batch_embeddings in zip(batches, embeddings_batches):
        if batch_embeddings is None:
            # NOTE(review): __call__ documents None for an item whose
            # embedding failed; here an item is a whole batch, so presumably
            # a failed batch request comes back as a single None. Expand it
            # so the result keeps one entry per document - confirm against
            # the server communicator.
            embeddings.extend([None] * len(batch))
        else:
            embeddings.extend(batch_embeddings)
    return embeddings
153+
114154
def report(self) -> Tuple[Tuple[str, str], ...]:
115155
"""Reports on current parameters of DocumentEmbedder.
116156
@@ -164,10 +204,20 @@ def embedd_data(
164204
else:
165205
return asyncio.run(self.embedd_batch(data, callback=callback))
166206

167-
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
168-
data = base64.b64encode(
169-
zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
170-
).decode("utf-8", "replace")
207+
async def _encode_data_instance(
208+
self, data_instance: Union[str, List[str]]
209+
) -> Optional[bytes]:
210+
def compress_text(text):
211+
return base64.b64encode(
212+
zlib.compress(text.encode("utf-8", "replace"), level=-1)
213+
).decode("utf-8", "replace")
214+
215+
if isinstance(data_instance, str):
216+
# single document in request
217+
data = compress_text(data_instance)
218+
else:
219+
# request is batch (list of documents)
220+
data = [compress_text(text) for text in data_instance]
171221
if sys.getsizeof(data) > 500000:
172222
# Document in corpus is too large. Size limit is 500 KB
173223
# (after compression). - document skipped

0 commit comments

Comments
 (0)