1
+ import base64
2
+ import json
1
3
import unittest
2
- from unittest . mock import patch
3
- from collections . abc import Iterator
4
+ import zlib
5
+ from unittest . mock import patch , ANY
4
6
import asyncio
5
7
6
8
from orangecontrib .text .vectorization .sbert import SBERT , EMB_DIM
7
9
from orangecontrib .text import Corpus
8
10
9
11
# Dotted path of the attribute that the tests replace with a dummy coroutine.
PATCH_METHOD = 'httpx.AsyncClient.post'

# Canned embeddings keyed by document text: the document at position ``i`` of
# the "deerwester" corpus maps to a vector of EMB_DIM copies of ``i``.
RESPONSES = {
    text: [index] * EMB_DIM
    for index, text in enumerate(Corpus.from_file("deerwester").documents)
}

# Same table, except that the last document maps to ``None`` instead of a
# vector (used by the tests that expect one missing embedding).
RESPONSE_NONE = RESPONSES.copy()
RESPONSE_NONE[list(RESPONSE_NONE)[-1]] = None

# Expected embeddings in document order.
# NOTE(review): the literal 9 presumably equals the number of documents in
# the "deerwester" corpus -- confirm if the fixture ever changes.
IDEAL_RESPONSE = [[i] * EMB_DIM for i in range(9)]
16
18
17
19
18
20
class DummyResponse:
    """Minimal response object that only carries a ``content`` attribute."""

    def __init__(self, content):
        """Store ``content`` unchanged so callers can read it back."""
        self.content = content
22
23
23
24
24
def _decompress_text(instance):
    """Recover the plain text from one encoded request item.

    ``instance`` is a string; it is base64-decoded, the resulting bytes are
    zlib-decompressed, and the output is decoded as UTF-8.
    """
    compressed = base64.b64decode(instance.encode("utf-8"))
    return zlib.decompress(compressed).decode("utf-8")


def make_dummy_post(responses, sleep=0):
    """Build a replacement ``post`` coroutine that answers from a lookup table.

    Parameters
    ----------
    responses : mapping
        Maps a decoded document text to the embedding returned for it.
        A text that is missing from the mapping raises ``KeyError``.
    sleep : int or float
        Seconds to wait (via ``asyncio.sleep``) before answering.

    Returns
    -------
    staticmethod
        Wraps an ``async`` function with the signature
        ``(url, headers, data=None, content=None)``. The request body is
        parsed as JSON holding either one encoded text or a list of them;
        the reply is a ``DummyResponse`` whose ``content`` is the UTF-8
        encoded JSON ``{"embedding": ...}`` -- a list of embeddings for a
        list request, a single embedding otherwise.
    """
    # staticmethod keeps the function from being bound to an instance when
    # it is installed as a class attribute by ``patch``.
    @staticmethod
    async def dummy_post(url, headers, data=None, content=None):
        assert data or content
        await asyncio.sleep(sleep)
        # The assert above accepts a body passed through either keyword, so
        # read whichever one was supplied instead of assuming ``content``.
        raw_body = content or data
        payload = json.loads(raw_body.decode("utf-8", "replace"))
        is_batch = isinstance(payload, list)
        instances = payload if is_batch else [payload]
        embeddings = [responses[_decompress_text(item)] for item in instances]
        reply = {"embedding": embeddings if is_batch else embeddings[0]}
        return DummyResponse(content=json.dumps(reply).encode("utf-8"))

    return dummy_post
33
41
34
42
@@ -51,25 +59,25 @@ def test_empty_corpus(self, mock):
51
59
dict ()
52
60
)
53
61
54
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_success(self):
    """Calling ``self.sbert`` on the corpus documents returns every embedding.

    The patched ``post`` answers each text from ``RESPONSES``, so the result
    must equal ``IDEAL_RESPONSE`` exactly, including its order.
    """
    result = self.sbert(self.corpus.documents)
    self.assertEqual(result, IDEAL_RESPONSE)
58
66
59
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_none_result(self):
    """A ``None`` embedding from the server is passed through in place.

    ``RESPONSE_NONE`` maps its last text to ``None``; the result is expected
    to match ``IDEAL_RESPONSE`` except for a trailing ``None``.
    """
    result = self.sbert(self.corpus.documents)
    self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
63
71
64
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_transform(self):
    """``transform`` keeps every document when all embeddings are returned.

    Checks that nothing is reported as skipped, that the output has as many
    rows as the corpus, that the metas are unchanged and that the output
    domain has 384 attributes.
    """
    res, skipped = self.sbert.transform(self.corpus)
    self.assertIsNone(skipped)
    self.assertEqual(len(self.corpus), len(res))
    self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
    # NOTE(review): 384 presumably equals EMB_DIM -- confirm before replacing.
    self.assertEqual(384, len(res.domain.attributes))
71
79
72
- @patch (PATCH_METHOD , make_dummy_post (iter ( RESPONSE [: - 1 ] + [ None ] * 3 ) ))
80
+ @patch (PATCH_METHOD , make_dummy_post (RESPONSE_NONE ))
73
81
def test_transform_skipped (self ):
74
82
res , skipped = self .sbert .transform (self .corpus )
75
83
self .assertEqual (len (self .corpus ) - 1 , len (res ))
@@ -80,6 +88,29 @@ def test_transform_skipped(self):
80
88
self .assertTupleEqual (self .corpus .domain .metas , skipped .domain .metas )
81
89
self .assertEqual (0 , len (skipped .domain .attributes ))
82
90
91
@patch(PATCH_METHOD, make_dummy_post(RESPONSES))
def test_batches_success(self):
    """``embed_batches`` returns the same embeddings for every batch size.

    Batch sizes 1 through 10 are tried; each runs as its own sub-test so a
    failure names the offending size and the remaining sizes still run.
    """
    for batch_size in range(1, 11):  # try different batch sizes
        with self.subTest(batch_size=batch_size):
            result = self.sbert.embed_batches(self.corpus.documents, batch_size)
            self.assertEqual(result, IDEAL_RESPONSE)
96
+
97
@patch(PATCH_METHOD, make_dummy_post(RESPONSE_NONE))
def test_batches_none_result(self):
    """``embed_batches`` keeps a ``None`` embedding in place for every batch size.

    ``RESPONSE_NONE`` maps its last text to ``None``; for batch sizes 1
    through 10 the result must match ``IDEAL_RESPONSE`` except for a trailing
    ``None``. Each size runs as its own sub-test so a failure names it.
    """
    for batch_size in range(1, 11):  # try different batch sizes
        with self.subTest(batch_size=batch_size):
            result = self.sbert.embed_batches(self.corpus.documents, batch_size)
            self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])
102
+
103
@patch("orangecontrib.text.vectorization.sbert._ServerCommunicator.embedd_data")
def test_reordered(self, mock):
    """Test that texts are reordered according to their length.

    ``embedd_data`` is replaced by a mock so only its call arguments are
    inspected: it must receive the inputs as a tuple sorted by ``len`` in
    descending order, together with some ``callback`` keyword argument.
    """
    self.sbert(self.corpus.documents)
    mock.assert_called_with(
        tuple(sorted(self.corpus.documents, key=len, reverse=True)), callback=ANY
    )

    # Same check with lists as items: ordering follows the list lengths.
    self.sbert([["1", "2"], ["4", "5", "6"], ["0"]])
    mock.assert_called_with((["4", "5", "6"], ["1", "2"], ["0"]), callback=ANY)
113
+
83
114
84
115
# Run the tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
0 commit comments