[ENH] Score Documents - Use SBERT embedding instead of FastText #930

Merged 2 commits on Mar 6, 2023
63 changes: 47 additions & 16 deletions orangecontrib/text/tests/test_sbert.py
@@ -1,34 +1,42 @@
import base64
import json
import unittest
from unittest.mock import patch
from collections.abc import Iterator
import zlib
from unittest.mock import patch, ANY
import asyncio

from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
RESPONSE = [
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
for i in range(9)
]

RESPONSES = {
t: [i] * EMB_DIM for i, t in enumerate(Corpus.from_file("deerwester").documents)
}
RESPONSE_NONE = RESPONSES.copy()
RESPONSE_NONE[list(RESPONSE_NONE.keys())[-1]] = None
IDEAL_RESPONSE = [[i] * EMB_DIM for i in range(9)]


class DummyResponse:

def __init__(self, content):
self.content = content


def make_dummy_post(response, sleep=0):
def _decompress_text(instance):
return zlib.decompress(base64.b64decode(instance.encode("utf-8"))).decode("utf-8")


def make_dummy_post(responses, sleep=0):
@staticmethod
async def dummy_post(url, headers, data=None, content=None):
assert data or content
await asyncio.sleep(sleep)
return DummyResponse(
content=next(response) if isinstance(response, Iterator) else response
)
data = json.loads(content.decode("utf-8", "replace"))
data_ = data if isinstance(data, list) else [data]
texts = [_decompress_text(instance) for instance in data_]
responses_ = [responses[t] for t in texts]
r = {"embedding": responses_ if isinstance(data, list) else responses_[0]}
return DummyResponse(content=json.dumps(r).encode("utf-8"))
return dummy_post


@@ -51,25 +59,25 @@ def test_empty_corpus(self, mock):
dict()
)

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_success(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE)

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_none_result(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_transform(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertIsNone(skipped)
self.assertEqual(len(self.corpus), len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_transform_skipped(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertEqual(len(self.corpus) - 1, len(res))
@@ -80,6 +88,29 @@ def test_transform_skipped(self):
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
self.assertEqual(0, len(skipped.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_batches_success(self):
for i in range(1, 11): # try different batch sizes
result = self.sbert.embed_batches(self.corpus.documents, i)
self.assertEqual(result, IDEAL_RESPONSE)

@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_batches_none_result(self):
for i in range(1, 11): # try different batch sizes
result = self.sbert.embed_batches(self.corpus.documents, i)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch("orangecontrib.text.vectorization.sbert._ServerCommunicator.embedd_data")
def test_reordered(self, mock):
"""Test that texts are reordered according to their length"""
self.sbert(self.corpus.documents)
mock.assert_called_with(
tuple(sorted(self.corpus.documents, key=len, reverse=True)), callback=ANY
)

self.sbert([["1", "2"], ["4", "5", "6"], ["0"]])
mock.assert_called_with((["4", "5", "6"], ["1", "2"], ["0"]), callback=ANY)


if __name__ == "__main__":
unittest.main()
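
The mock above decodes exactly what the embedder now sends: each text is zlib-compressed and base64-encoded on the client, and the dummy server reverses that per item. As a standalone illustration of this round-trip (a sketch only; compress_text/decompress_text here are local helpers mirroring the diff, and the JSON framing follows what dummy_post decodes, since the real encoder's serialization step is elided above):

import base64
import json
import zlib

def compress_text(text: str) -> str:
    # client side: zlib-compress, then base64-encode (mirrors sbert.py)
    return base64.b64encode(
        zlib.compress(text.encode("utf-8", "replace"), level=-1)
    ).decode("utf-8")

def decompress_text(instance: str) -> str:
    # mock-server side: reverse the encoding (mirrors _decompress_text above)
    return zlib.decompress(base64.b64decode(instance.encode("utf-8"))).decode("utf-8")

# a batched request body as dummy_post decodes it: a JSON list of encoded texts
batch = ["human machine interface", "graph minors survey"]
content = json.dumps([compress_text(t) for t in batch]).encode("utf-8")
decoded = [decompress_text(t) for t in json.loads(content.decode("utf-8"))]
assert decoded == batch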
70 changes: 60 additions & 10 deletions orangecontrib/text/vectorization/sbert.py
@@ -5,7 +5,7 @@
import zlib
import sys
from threading import Thread
from typing import Any, List, Optional, Callable, Tuple
from typing import Any, List, Optional, Callable, Tuple, Union

import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
@@ -29,18 +29,20 @@ def __init__(self) -> None:
)

def __call__(
self, texts: List[str], callback: Callable = dummy_callback
) -> List[Optional[List[float]]]:
self, texts: List[Union[str, List[str]]], callback: Callable = dummy_callback
) -> List[Union[Optional[List[float]], List[Optional[List[float]]]]]:
"""Computes embeddings for given documents.

Parameters
----------
texts
A list of raw texts.
A list of texts, or a list of text batches (each batch being a list of texts)

Returns
-------
An array of embeddings.
List of embeddings for each document. Each item is either a list of
numbers (the embedding) or None when embedding fails.
When texts is a list of batches, responses are returned in batches as well.
"""
if len(texts) == 0:
return []
@@ -49,7 +51,7 @@ def __call__(
# at the end and thus add extra time to the complete embedding time
sorted_texts = sorted(
enumerate(texts),
key=lambda x: len(x[1][0]) if x[1] is not None else 0,
key=lambda x: len(x[1]) if x[1] is not None else 0,
reverse=True,
)
indices, sorted_texts = zip(*sorted_texts)
@@ -111,6 +113,44 @@ def _transform(

return new_corpus, skipped_corpus

def embed_batches(
self,
documents: List[str],
batch_size: int,
*,
callback: Callable = dummy_callback
) -> List[Optional[List[float]]]:
"""
Embed documents by sending batches of documents to the server instead of
one document per request. This method is recommended when the documents
are single words or other very short texts: they embed quickly, so the
bottleneck is the request round-trip, and batching requests is faster.
For documents of at least a few sentences, the bottleneck is the
embedding itself; sending them in separate requests can then speed
things up, since the work is spread across more workers.

Parameters
----------
documents
List of documents to be sent to the server
batch_size
Number of documents in one batch sent to the server
callback
Callback for reporting the progress

Returns
-------
List of embeddings for each document. Each item is either a list of
numbers (the embedding) or None when embedding fails.
"""
batches = [
documents[ndx : ndx + batch_size]
for ndx in range(0, len(documents), batch_size)
]
embeddings_batches = self(batches)
return [emb for batch in embeddings_batches for emb in batch]

def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.

@@ -164,10 +204,20 @@ def embedd_data(
else:
return asyncio.run(self.embedd_batch(data, callback=callback))

async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
data = base64.b64encode(
zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")
async def _encode_data_instance(
self, data_instance: Union[str, List[str]]
) -> Optional[bytes]:
def compress_text(text):
return base64.b64encode(
zlib.compress(text.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")

if isinstance(data_instance, str):
# single document in request
data = compress_text(data_instance)
else:
# request is batch (list of documents)
data = [compress_text(text) for text in data_instance]
if sys.getsizeof(data) > 500000:
# Document in corpus is too large. Size limit is 500 KB
# (after compression). - document skipped
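For orientation, a brief usage sketch of the two entry points touched by this file, the per-document call and the new embed_batches (illustrative only; it assumes the embedding server is reachable, and the example words are made up):

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.sbert import SBERT

corpus = Corpus.from_file("deerwester")
sbert = SBERT()

# one request per document; each item is an embedding (list of floats) or None
document_embeddings = sbert(corpus.documents)

# batched requests; suited to words and other very short texts
words = ["human", "interface", "graph", "minors", "survey"]
word_embeddings = sbert.embed_batches(words, batch_size=50)
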
70 changes: 43 additions & 27 deletions orangecontrib/text/widgets/owscoredocuments.py
@@ -35,11 +35,9 @@

from orangecontrib.text import Corpus
from orangecontrib.text.preprocess import BaseNormalizer, NGrams, BaseTokenFilter
from orangecontrib.text.vectorization.document_embedder import (
LANGS_TO_ISO,
DocumentEmbedder,
)
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils import enum2int

from orangecontrib.text.widgets.utils.words import create_words_table


@@ -69,24 +67,29 @@ def _embedding_similarity(
corpus: Corpus,
words: List[str],
callback: Callable,
embedding_language: str,
) -> np.ndarray:
language = LANGS_TO_ISO[embedding_language]
# make sure there will be only embeddings in X after calling the embedder
corpus = Corpus.from_table(Domain([], metas=corpus.domain.metas), corpus)
emb = DocumentEmbedder(language)
emb = SBERT()

cb_part = len(corpus) / (len(corpus) + len(words))
documet_embeddings, skipped = emb.transform(
corpus, wrap_callback(callback, 0, cb_part)
)
assert skipped is None

words = [[w] for w in words]
word_embeddings = np.array(
emb.transform(words, wrap_callback(callback, cb_part, 1 - cb_part))
)
return cosine_similarity(documet_embeddings.X, word_embeddings)
if skipped:
# raise when any embedding fails. Distances could instead be computed
# only for the valid embeddings, but that is not worth it: cases where
# only some documents fail to embed are extremely rare - usually a
# network error makes embedding fail for all documents
raise ValueError("Some documents not embedded; try to rerun scoring")

# words are very short documents - embed them in batches to limit request overhead
w_emb = emb.embed_batches(words, batch_size=50)
if any(x is None for x in w_emb):
# raise when some words are not embedded; computing scores from only
# the valid word embeddings would give wrong results
raise ValueError("Some words not embedded; try to rerun scoring")
return cosine_similarity(documet_embeddings.X, np.array(w_emb))


SCORING_METHODS = {
@@ -108,9 +111,7 @@ def _embedding_similarity(
),
}

ADDITIONAL_OPTIONS = {
"embedding_similarity": ("embedding_language", list(LANGS_TO_ISO.keys()))
}
ADDITIONAL_OPTIONS = {}

AGGREGATIONS = {
"Mean": np.mean,
@@ -137,6 +138,7 @@ def _preprocess_words(
np.empty((len(words), 0)),
metas=np.array([[w] for w in words]),
text_features=[words_feature],
language=corpus.language
)
# apply all corpus preprocessors except Filter and NGrams, which change terms
# filter removes words from the term, and NGrams split the term in grams.
@@ -193,12 +195,16 @@ def callback(i: float) -> None:
scoring_method = SCORING_METHODS[sm][1]
sig = signature(scoring_method)
add_params = {k: v for k, v in additional_params.items() if k in sig.parameters}
scs = scoring_method(
corpus,
words,
wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
**add_params
)
try:
scs = scoring_method(
corpus,
words,
wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
**add_params
)
except ValueError as ex:
state.set_partial_result((sm, aggregation, str(ex)))
continue
scs = AGGREGATIONS[aggregation](scs, axis=1)
state.set_partial_result((sm, aggregation, scs))

@@ -328,7 +334,6 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin):
word_frequency: bool = Setting(True)
word_appearance: bool = Setting(False)
embedding_similarity: bool = Setting(False)
embedding_language: int = Setting(0)

sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING)
selected_rows: List[int] = ContextSetting([], schema_only=True)
@@ -345,6 +350,7 @@ class Outputs:

class Warning(OWWidget.Warning):
corpus_not_normalized = Msg("Use Preprocess Text to normalize corpus.")
scoring_warning = Msg("{}")

class Error(OWWidget.Error):
custom_err = Msg("{}")
@@ -622,6 +628,7 @@ def __setting_changed(self) -> None:
@gui.deferred
def commit(self) -> None:
self.Error.custom_err.clear()
self.Warning.scoring_warning.clear()
self.cancel()
if self.corpus is not None and self.words is not None:
scorers = self._get_active_scorers()
@@ -645,10 +652,19 @@ def commit(self) -> None:
def on_done(self, _: None) -> None:
self._send_output()

def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None:
def on_partial_result(
self, result: Tuple[str, str, Union[np.ndarray, str]]
) -> None:
sc_method, aggregation, scores = result
self.scores[(sc_method, aggregation)] = scores
self._fill_table()
if isinstance(scores, str):
# scoring failed - scores holds the error message
self.Warning.scoring_warning(
f"{SCORING_METHODS[sc_method][0]} failed: {scores}"
)
else:
# scoring successful
self.scores[(sc_method, aggregation)] = scores
self._fill_table()

def on_exception(self, ex: Exception) -> None:
self.Error.custom_err(ex)
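
To make the scoring step in _embedding_similarity concrete, here is a shape-level sketch with random stand-in vectors (cosine_similarity is assumed to be scikit-learn's, as used in the widget): documents and words each map to 384-dimensional embeddings, cosine similarity yields an n_documents x n_words matrix, and the chosen aggregation collapses it to one score per document.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# stand-in shapes only; real vectors come from the SBERT server (EMB_DIM == 384)
n_docs, n_words, emb_dim = 9, 5, 384
doc_embeddings = np.random.rand(n_docs, emb_dim)    # documet_embeddings.X in the widget
word_embeddings = np.random.rand(n_words, emb_dim)  # np.array(w_emb) in the widget

similarities = cosine_similarity(doc_embeddings, word_embeddings)  # (n_docs, n_words)
scores = np.mean(similarities, axis=1)  # "Mean" aggregation: one score per document
assert scores.shape == (n_docs,)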