From 7b618aa14089c549f474d08d427ab761fd12d792 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Tue, 16 Aug 2022 16:02:26 +0200
Subject: [PATCH 1/2] Score documents - replace fastText with SBERT embedding
---
.../text/widgets/owscoredocuments.py | 70 +++++++----
.../widgets/tests/test_owscoredocuments.py | 117 +++++++++++++++++-
2 files changed, 155 insertions(+), 32 deletions(-)
diff --git a/orangecontrib/text/widgets/owscoredocuments.py b/orangecontrib/text/widgets/owscoredocuments.py
index 404e8c1d0..0c737811b 100644
--- a/orangecontrib/text/widgets/owscoredocuments.py
+++ b/orangecontrib/text/widgets/owscoredocuments.py
@@ -35,11 +35,9 @@
from orangecontrib.text import Corpus
from orangecontrib.text.preprocess import BaseNormalizer, NGrams, BaseTokenFilter
-from orangecontrib.text.vectorization.document_embedder import (
- LANGS_TO_ISO,
- DocumentEmbedder,
-)
+from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils import enum2int
+
from orangecontrib.text.widgets.utils.words import create_words_table
@@ -69,24 +67,29 @@ def _embedding_similarity(
corpus: Corpus,
words: List[str],
callback: Callable,
- embedding_language: str,
) -> np.ndarray:
- language = LANGS_TO_ISO[embedding_language]
# make sure there will be only embeddings in X after calling the embedder
corpus = Corpus.from_table(Domain([], metas=corpus.domain.metas), corpus)
- emb = DocumentEmbedder(language)
+ emb = SBERT()
cb_part = len(corpus) / (len(corpus) + len(words))
documet_embeddings, skipped = emb.transform(
corpus, wrap_callback(callback, 0, cb_part)
)
- assert skipped is None
-
- words = [[w] for w in words]
- word_embeddings = np.array(
- emb.transform(words, wrap_callback(callback, cb_part, 1 - cb_part))
- )
- return cosine_similarity(documet_embeddings.X, word_embeddings)
+ if skipped:
+        # raise when any embedding fails; distances could instead be computed
+        # only for the valid embeddings, but that is not worth it - cases where
+        # only some documents fail to embed are extremely rare, since a network
+        # error usually makes embedding fail for all documents
+ raise ValueError("Some documents not embedded; try to rerun scoring")
+
+    # words are very short documents, so embed them in batches to avoid
+    # sending one request per word
+ w_emb = emb.embed_batches(words, batch_size=50)
+ if any(x is None for x in w_emb):
+        # raise when some words are not embedded; using only the valid word
+        # embeddings would give wrong results
+ raise ValueError("Some words not embedded; try to rerun scoring")
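+    # similarity matrix of shape (number of documents, number of words)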
+ return cosine_similarity(documet_embeddings.X, np.array(w_emb))
SCORING_METHODS = {
@@ -108,9 +111,7 @@ def _embedding_similarity(
),
}
-ADDITIONAL_OPTIONS = {
- "embedding_similarity": ("embedding_language", list(LANGS_TO_ISO.keys()))
-}
+ADDITIONAL_OPTIONS = {}
AGGREGATIONS = {
"Mean": np.mean,
@@ -137,6 +138,7 @@ def _preprocess_words(
np.empty((len(words), 0)),
metas=np.array([[w] for w in words]),
text_features=[words_feature],
+ language=corpus.language
)
# apply all corpus preprocessors except Filter and NGrams, which change terms
# filter removes words from the term, and NGrams split the term in grams.
@@ -193,12 +195,16 @@ def callback(i: float) -> None:
scoring_method = SCORING_METHODS[sm][1]
sig = signature(scoring_method)
add_params = {k: v for k, v in additional_params.items() if k in sig.parameters}
- scs = scoring_method(
- corpus,
- words,
- wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
- **add_params
- )
+ try:
+ scs = scoring_method(
+ corpus,
+ words,
+ wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
+ **add_params
+ )
+ except ValueError as ex:
+ state.set_partial_result((sm, aggregation, str(ex)))
+ continue
scs = AGGREGATIONS[aggregation](scs, axis=1)
state.set_partial_result((sm, aggregation, scs))
@@ -328,7 +334,6 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin):
word_frequency: bool = Setting(True)
word_appearance: bool = Setting(False)
embedding_similarity: bool = Setting(False)
- embedding_language: int = Setting(0)
sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING)
selected_rows: List[int] = ContextSetting([], schema_only=True)
@@ -345,6 +350,7 @@ class Outputs:
class Warning(OWWidget.Warning):
corpus_not_normalized = Msg("Use Preprocess Text to normalize corpus.")
+ scoring_warning = Msg("{}")
class Error(OWWidget.Error):
custom_err = Msg("{}")
@@ -622,6 +628,7 @@ def __setting_changed(self) -> None:
@gui.deferred
def commit(self) -> None:
self.Error.custom_err.clear()
+ self.Warning.scoring_warning.clear()
self.cancel()
if self.corpus is not None and self.words is not None:
scorers = self._get_active_scorers()
@@ -645,10 +652,19 @@ def commit(self) -> None:
def on_done(self, _: None) -> None:
self._send_output()
- def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None:
+ def on_partial_result(
+ self, result: Tuple[str, str, Union[np.ndarray, str]]
+ ) -> None:
sc_method, aggregation, scores = result
- self.scores[(sc_method, aggregation)] = scores
- self._fill_table()
+ if isinstance(scores, str):
+            # scoring failed; scores holds the error message
+ self.Warning.scoring_warning(
+ f"{SCORING_METHODS[sc_method][0]} failed: {scores}"
+ )
+ else:
+ # scoring successful
+ self.scores[(sc_method, aggregation)] = scores
+ self._fill_table()
def on_exception(self, ex: Exception) -> None:
self.Error.custom_err(ex)
diff --git a/orangecontrib/text/widgets/tests/test_owscoredocuments.py b/orangecontrib/text/widgets/tests/test_owscoredocuments.py
index 8561551b8..ec9253c01 100644
--- a/orangecontrib/text/widgets/tests/test_owscoredocuments.py
+++ b/orangecontrib/text/widgets/tests/test_owscoredocuments.py
@@ -13,7 +13,7 @@
from Orange.widgets.tests.utils import simulate
from orangecontrib.text import Corpus, preprocess
-from orangecontrib.text.vectorization.document_embedder import _ServerEmbedder
+from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.owscoredocuments import (
OWScoreDocuments,
SelectionMethods,
@@ -22,8 +22,12 @@
from orangecontrib.text.widgets.utils.words import create_words_table
-def embedding_mock(_, data, callback=None):
- return np.ones((len(data), 10))
+def embedding_mock(_, data, batch_size=None, callback=None):
+ return np.ones((len(data), 10)).tolist()
+
+
+def embedding_mock_none(_, data, batch_size=None, callback=None):
+ return np.ones((len(data) - 1, 10)).tolist() + [None]
class TestOWScoreDocuments(WidgetTest):
@@ -117,7 +121,8 @@ def test_guess_word_attribute(self):
self.send_signal(self.widget.Inputs.words, None)
self.assertIsNone(self.widget.words)
- @patch.object(_ServerEmbedder, "embedd_data", new=embedding_mock)
+ @patch.object(SBERT, "embed_batches", new=embedding_mock)
+ @patch.object(SBERT, "__call__", new=embedding_mock)
def test_change_scorer(self):
model = self.widget.model
self.send_signal(self.widget.Inputs.corpus, self.corpus)
@@ -155,6 +160,7 @@ def create_corpus(texts: List[str]) -> Corpus:
X=np.empty((len(texts), 0)),
metas=np.array(texts).reshape(-1, 1),
text_features=[text_var],
+ language="en"
)
return preprocess.LowercaseTransformer()(c)
@@ -229,7 +235,8 @@ def test_word_appearance(self):
self.assertTrue(all(isinstance(s, float) for s in scores))
self.assertListEqual(scores, [0, 0])
- @patch.object(_ServerEmbedder, "embedd_data", new=embedding_mock)
+ @patch.object(SBERT, "embed_batches", new=embedding_mock)
+ @patch.object(SBERT, "__call__", new=embedding_mock)
def test_embedding_similarity(self):
corpus = self.create_corpus(
[
@@ -453,6 +460,106 @@ def test_titles_no_newline(self):
"The Little Match-Seller test", self.widget.view.model().index(0, 0).data()
)
+ @patch.object(SBERT, "embed_batches", new=embedding_mock)
+ @patch.object(SBERT, "__call__")
+ def test_warning_unsuccessful_scoring(self, emb_mock):
+        """Test the case when embedding fails for at least one document"""
+ emb_mock.return_value = np.ones((len(self.corpus) - 1, 10)).tolist() + [None]
+
+ model = self.widget.model
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.send_signal(self.widget.Inputs.words, self.words)
+ self.wait_until_finished()
+
+ # scoring fails
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount()) # name and word count, no similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
+ self.assertEqual(
+ "Similarity failed: Some documents not embedded; try to rerun scoring",
+ str(self.widget.Warning.scoring_warning),
+ )
+
+        # uncheck similarity - rerun without the failing scorer
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount())
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertFalse(self.widget.Warning.scoring_warning.is_shown())
+
+ # run failing scoring again
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount()) # name and word count, no similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
+ self.assertEqual(
+ "Similarity failed: Some documents not embedded; try to rerun scoring",
+ str(self.widget.Warning.scoring_warning),
+ )
+
+        # run scoring again; this time it succeeds and the warning should disappear
+ emb_mock.return_value = np.ones((len(self.corpus), 10)).tolist()
+ self.widget.controls.embedding_similarity.click()
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+        self.assertEqual(3, model.columnCount())  # name, word count and similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertEqual(model.headerData(2, Qt.Horizontal), "Similarity")
+ self.assertFalse(self.widget.Warning.scoring_warning.is_shown())
+
+ @patch.object(SBERT, "embed_batches")
+ @patch.object(SBERT, "__call__", new=embedding_mock)
+ def test_warning_unsuccessful_scoring_words(self, emb_mock):
+        """Test the case when word embedding fails for at least one word"""
+ emb_mock.return_value = np.ones((len(self.words), 10)).tolist() + [None]
+
+ model = self.widget.model
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.send_signal(self.widget.Inputs.words, self.words)
+ self.wait_until_finished()
+
+ # scoring fails
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount()) # name and word count, no similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
+ self.assertEqual(
+ "Similarity failed: Some words not embedded; try to rerun scoring",
+ str(self.widget.Warning.scoring_warning),
+ )
+
+        # uncheck similarity - rerun without the failing scorer
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount())
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertFalse(self.widget.Warning.scoring_warning.is_shown())
+
+ # run failing scoring again
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+ self.assertEqual(2, model.columnCount()) # name and word count, no similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
+ self.assertEqual(
+ "Similarity failed: Some words not embedded; try to rerun scoring",
+ str(self.widget.Warning.scoring_warning),
+ )
+
+        # run scoring again; this time it succeeds and the warning should disappear
+ emb_mock.return_value = np.ones((len(self.words), 10)).tolist()
+ self.widget.controls.embedding_similarity.click()
+ self.widget.controls.embedding_similarity.click()
+ self.wait_until_finished()
+        self.assertEqual(3, model.columnCount())  # name, word count and similarity
+ self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
+ self.assertEqual(model.headerData(2, Qt.Horizontal), "Similarity")
+ self.assertFalse(self.widget.Warning.scoring_warning.is_shown())
+
def test_n_grams(self):
texts = [
"Lorem ipsum dolor sit ipsum consectetur adipiscing elit dolor sit eu",
From 9e01a2e2ec6411fd6a58383f56f333bef8add026 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Wed, 18 Jan 2023 16:16:14 +0100
Subject: [PATCH 2/2] SBERT - enable batch embedding and fix sorting
---
orangecontrib/text/tests/test_sbert.py | 63 ++++++++++++++------
orangecontrib/text/vectorization/sbert.py | 70 +++++++++++++++++++----
2 files changed, 107 insertions(+), 26 deletions(-)
diff --git a/orangecontrib/text/tests/test_sbert.py b/orangecontrib/text/tests/test_sbert.py
index 4784b3ca1..823a9cd02 100644
--- a/orangecontrib/text/tests/test_sbert.py
+++ b/orangecontrib/text/tests/test_sbert.py
@@ -1,34 +1,42 @@
+import base64
+import json
import unittest
-from unittest.mock import patch
-from collections.abc import Iterator
+import zlib
+from unittest.mock import patch, ANY
import asyncio
from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
from orangecontrib.text import Corpus
PATCH_METHOD = 'httpx.AsyncClient.post'
-RESPONSE = [
- f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
- for i in range(9)
-]
-
+RESPONSES = {
+ t: [i] * EMB_DIM for i, t in enumerate(Corpus.from_file("deerwester").documents)
+}
+RESPONSE_NONE = RESPONSES.copy()
+RESPONSE_NONE[list(RESPONSE_NONE.keys())[-1]] = None
IDEAL_RESPONSE = [[i] * EMB_DIM for i in range(9)]
class DummyResponse:
-
def __init__(self, content):
self.content = content
-def make_dummy_post(response, sleep=0):
+def _decompress_text(instance):
+ return zlib.decompress(base64.b64decode(instance.encode("utf-8"))).decode("utf-8")
+
+
+def make_dummy_post(responses, sleep=0):
@staticmethod
async def dummy_post(url, headers, data=None, content=None):
assert data or content
await asyncio.sleep(sleep)
- return DummyResponse(
- content=next(response) if isinstance(response, Iterator) else response
- )
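+        # request content is JSON holding either one compressed text or a list
+        # of compressed texts (a batch); decompress each text and return the
+        # matching canned embedding(s) in the same shape as the request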
+ data = json.loads(content.decode("utf-8", "replace"))
+ data_ = data if isinstance(data, list) else [data]
+ texts = [_decompress_text(instance) for instance in data_]
+ responses_ = [responses[t] for t in texts]
+ r = {"embedding": responses_ if isinstance(data, list) else responses_[0]}
+ return DummyResponse(content=json.dumps(r).encode("utf-8"))
return dummy_post
@@ -51,17 +59,17 @@ def test_empty_corpus(self, mock):
dict()
)
- @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_success(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE)
- @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_none_result(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
- @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_transform(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertIsNone(skipped)
@@ -69,7 +77,7 @@ def test_transform(self):
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))
- @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_transform_skipped(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertEqual(len(self.corpus) - 1, len(res))
@@ -80,6 +88,29 @@ def test_transform_skipped(self):
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
self.assertEqual(0, len(skipped.domain.attributes))
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSES))
+ def test_batches_success(self):
+ for i in range(1, 11): # try different batch sizes
+ result = self.sbert.embed_batches(self.corpus.documents, i)
+ self.assertEqual(result, IDEAL_RESPONSE)
+
+ @patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
+ def test_batches_none_result(self):
+ for i in range(1, 11): # try different batch sizes
+ result = self.sbert.embed_batches(self.corpus.documents, i)
+ self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
+
+ @patch("orangecontrib.text.vectorization.sbert._ServerCommunicator.embedd_data")
+ def test_reordered(self, mock):
+        """Test that texts are sorted by decreasing length before being sent to the server"""
+ self.sbert(self.corpus.documents)
+ mock.assert_called_with(
+ tuple(sorted(self.corpus.documents, key=len, reverse=True)), callback=ANY
+ )
+
+ self.sbert([["1", "2"], ["4", "5", "6"], ["0"]])
+ mock.assert_called_with((["4", "5", "6"], ["1", "2"], ["0"]), callback=ANY)
+
if __name__ == "__main__":
unittest.main()
diff --git a/orangecontrib/text/vectorization/sbert.py b/orangecontrib/text/vectorization/sbert.py
index ec1244773..9da4b61f5 100644
--- a/orangecontrib/text/vectorization/sbert.py
+++ b/orangecontrib/text/vectorization/sbert.py
@@ -5,7 +5,7 @@
import zlib
import sys
from threading import Thread
-from typing import Any, List, Optional, Callable, Tuple
+from typing import Any, List, Optional, Callable, Tuple, Union
import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
@@ -29,18 +29,20 @@ def __init__(self) -> None:
)
def __call__(
- self, texts: List[str], callback: Callable = dummy_callback
- ) -> List[Optional[List[float]]]:
+ self, texts: List[Union[str, List[str]]], callback: Callable = dummy_callback
+ ) -> List[Union[Optional[List[float]], List[Optional[List[float]]]]]:
"""Computes embeddings for given documents.
Parameters
----------
texts
- A list of raw texts.
+            A list of texts or a list of text batches (each batch is a list of texts)
Returns
-------
- An array of embeddings.
+        A list with one item per text: either the embedding (a list of numbers)
+        or None when embedding that text fails. When texts is a list of
+        batches, the embeddings are returned in batches as well.
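+
+        Examples
+        --------
+        A minimal sketch (assumes a reachable embedding server)::
+
+            emb = SBERT()
+            doc_embs = emb(["first text", "second text"])  # one embedding (or None) per text
+            batch_embs = emb([["a", "b"], ["c"]])  # one list of embeddings per batch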
"""
if len(texts) == 0:
return []
@@ -49,7 +51,7 @@ def __call__(
# at the end and thus add extra time to the complete embedding time
sorted_texts = sorted(
enumerate(texts),
- key=lambda x: len(x[1][0]) if x[1] is not None else 0,
+ key=lambda x: len(x[1]) if x[1] is not None else 0,
reverse=True,
)
indices, sorted_texts = zip(*sorted_texts)
@@ -111,6 +113,44 @@ def _transform(
return new_corpus, skipped_corpus
+ def embed_batches(
+ self,
+ documents: List[str],
+ batch_size: int,
+ *,
+ callback: Callable = dummy_callback
+ ) -> List[Optional[List[float]]]:
+ """
+        Embed documents by sending batches of documents to the server instead
+        of one document per request. This method is suggested when the
+        documents are words or other very short texts: such texts embed
+        quickly, so the bottleneck is the number of requests, and sending them
+        in batches is faster. For documents of at least a few sentences the
+        bottleneck is the embedding itself; sending them as separate requests
+        lets the work be spread across more workers and can speed up the
+        embedding.
+
+ Parameters
+ ----------
+ documents
+            List of documents to be sent to the server
+ batch_size
+ Number of documents in one batch sent to the server
+ callback
+ Callback for reporting the progress
+
+ Returns
+ -------
+        A list with one item per document: either the embedding (a list of
+        numbers) or None when embedding that document fails.
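+
+        Examples
+        --------
+        A minimal sketch (assumes a reachable embedding server); each word is a
+        tiny document, so batching keeps the request overhead low::
+
+            word_embs = SBERT().embed_batches(["human", "interface"], batch_size=50)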
+ """
+ batches = [
+ documents[ndx : ndx + batch_size]
+ for ndx in range(0, len(documents), batch_size)
+ ]
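+        # embed all batches with a single call and flatten the per-batch
+        # results back into one embedding (or None) per document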
+ embeddings_batches = self(batches)
+ return [emb for batch in embeddings_batches for emb in batch]
+
def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.
@@ -164,10 +204,20 @@ def embedd_data(
else:
return asyncio.run(self.embedd_batch(data, callback=callback))
- async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
- data = base64.b64encode(
- zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
- ).decode("utf-8", "replace")
+ async def _encode_data_instance(
+ self, data_instance: Union[str, List[str]]
+ ) -> Optional[bytes]:
+ def compress_text(text):
+ return base64.b64encode(
+ zlib.compress(text.encode("utf-8", "replace"), level=-1)
+ ).decode("utf-8", "replace")
+
+ if isinstance(data_instance, str):
+ # single document in request
+ data = compress_text(data_instance)
+ else:
+ # request is batch (list of documents)
+ data = [compress_text(text) for text in data_instance]
if sys.getsizeof(data) > 500000:
# Document in corpus is too large. Size limit is 500 KB
# (after compression). - document skipped