Skip to content

Commit

Permalink
Document Embedding - add SBERT method to widget
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Jun 29, 2022
1 parent c16e7ad commit 9df9fff
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 25 deletions.
60 changes: 36 additions & 24 deletions orangecontrib/text/widgets/owdocumentembedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional, Any

from AnyQt.QtCore import Qt
from AnyQt.QtWidgets import QGridLayout, QLabel, QPushButton, QStyle
from AnyQt.QtWidgets import QVBoxLayout, QPushButton, QStyle
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from Orange.widgets import gui
from Orange.widgets.settings import Setting
Expand All @@ -13,7 +13,7 @@
LANGS_TO_ISO,
DocumentEmbedder,
)
from orangecontrib.text.widgets.utils import widgets
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils.owbasevectorizer import (
OWBaseVectorizer,
Vectorizer,
Expand All @@ -30,6 +30,7 @@ def _transform(self, callback):
self.new_corpus = embeddings
self.skipped_documents = skipped


class OWDocumentEmbedding(OWBaseVectorizer):
name = "Document Embedding"
description = "Document embedding using pretrained models."
Expand All @@ -40,7 +41,7 @@ class OWDocumentEmbedding(OWBaseVectorizer):
buttons_area_orientation = Qt.Vertical
settings_version = 2

Method = DocumentEmbedder
Methods = [DocumentEmbedder, SBERT]

class Outputs(OWBaseVectorizer.Outputs):
skipped = Output("Skipped documents", Corpus)
Expand All @@ -55,9 +56,9 @@ class Error(OWWidget.Error):
class Warning(OWWidget.Warning):
unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.")

method = Setting(default=0)
language = Setting(default="English")
aggregator = Setting(default="Mean")
method: int = Setting(default=0)
language: str = Setting(default="English")
aggregator: str = Setting(default="Mean")

def __init__(self):
super().__init__()
Expand All @@ -69,32 +70,43 @@ def __init__(self):
self.cancel_button.setDisabled(True)

def create_configuration_layout(self):
layout = QGridLayout()
layout.setSpacing(10)

combo = widgets.ComboBox(
layout = QVBoxLayout()
rbtns = gui.radioButtons(None, self, "method", callback=self.on_change)
layout.addWidget(rbtns)

gui.appendRadioButton(rbtns, "fastText:")
ibox = gui.indentedBox(rbtns)
gui.comboBox(
ibox,
self,
"language",
items=LANGUAGES,
label="Language:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
callback=self.on_change,
)
gui.comboBox(
ibox,
self,
"aggregator",
items=AGGREGATORS,
label="Aggregator:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
callback=self.on_change,
)
combo.currentIndexChanged.connect(self.on_change)
layout.addWidget(QLabel("Language:"))
layout.addWidget(combo, 0, 1)

combo = widgets.ComboBox(self, "aggregator", items=AGGREGATORS)
combo.currentIndexChanged.connect(self.on_change)
layout.addWidget(QLabel("Aggregator:"))
layout.addWidget(combo, 1, 1)

gui.appendRadioButton(rbtns, "Multilingual SBERT:")
return layout

def update_method(self):
self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus)

def init_method(self):
return self.Method(
language=LANGS_TO_ISO[self.language], aggregator=self.aggregator
)
params = dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator)
kwargs = (params, {})[self.method]
return self.Methods[self.method](**kwargs)

@gui.deferred
def commit(self):
Expand All @@ -103,15 +115,16 @@ def commit(self):
self.cancel_button.setDisabled(False)
super().commit()

def on_done(self, _):
def on_done(self, result):
self.cancel_button.setDisabled(True)
skipped = self.vectorizer.skipped_documents
self.Outputs.skipped.send(skipped)
if skipped is not None and len(skipped) > 0:
self.Warning.unsuccessful_embeddings()
super().on_done(_)
super().on_done(result)

def on_exception(self, ex: Exception):
raise ex
self.cancel_button.setDisabled(True)
if isinstance(ex, EmbeddingConnectionError):
self.Error.no_connection()
Expand All @@ -133,7 +146,6 @@ def migrate_settings(cls, settings: Dict[str, Any], version: Optional[int]):
settings["aggregator"] = AGGREGATORS[settings["aggregator"]]



if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview

Expand Down
18 changes: 18 additions & 0 deletions orangecontrib/text/widgets/tests/test_owdocumentembedding.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import unittest
from unittest.mock import Mock, patch

import numpy as np
from AnyQt.QtWidgets import QComboBox
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from PyQt5.QtWidgets import QRadioButton

from orangecontrib.text.tests.test_documentembedder import PATCH_METHOD, make_dummy_post
from orangecontrib.text.vectorization.sbert import EMB_DIM
from orangecontrib.text.widgets.owdocumentembedding import OWDocumentEmbedding
from orangecontrib.text import Corpus


async def none_method(_, __):
return None

_response_list = str(np.arange(0, EMB_DIM, dtype=float).tolist())
SBERT_RESPONSE = f'{{"embedding": [{_response_list}]}}'.encode()


class TestOWDocumentEmbedding(WidgetTest):
def setUp(self):
Expand Down Expand Up @@ -105,6 +111,18 @@ def test_skipped_documents(self):
self.assertEqual(len(self.get_output(self.widget.Outputs.skipped)), len(self.corpus))
self.assertTrue(self.widget.Warning.unsuccessful_embeddings.is_shown())

@patch(PATCH_METHOD, make_dummy_post(SBERT_RESPONSE))
def test_sbert(self):
self.widget.findChildren(QRadioButton)[1].click()
self.widget.vectorizer.method.clear_cache()

self.send_signal("Corpus", self.corpus)
result = self.get_output(self.widget.Outputs.corpus)
self.assertIsInstance(result, Corpus)
self.assertEqual(len(self.corpus), len(result))
self.assertTupleEqual(self.corpus.domain.metas, result.domain.metas)
self.assertEqual(384, len(result.domain.attributes))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion orangecontrib/text/widgets/utils/owbasevectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def on_change(self):
self.commit.deferred()

def send_report(self):
self.report_items(self.method.report())
self.report_items(self.vectorizer.method.report())

def create_configuration_layout(self):
raise NotImplementedError
Expand Down

0 comments on commit 9df9fff

Please sign in to comment.