4
4
using namespace unum ::usearch;
5
5
6
6
using distance_t = distance_punned_t ;
7
+ using metric_t = metric_punned_t ;
7
8
using index_t = index_dense_t ;
8
- using vector_view_t = span_gt<float >;
9
+ using vector_view_t = span_gt<double >;
9
10
10
11
using add_result_t = typename index_t ::add_result_t ;
11
12
using search_result_t = typename index_t ::search_result_t ;
12
13
using vector_key_t = typename index_t ::vector_key_t ;
14
+ using dense_search_result_t = typename index_t ::search_result_t ;
13
15
14
16
EXTERN_C DLLEXPORT int WolframLibrary_initialize (WolframLibraryData libData) { return LIBRARY_NO_ERROR; }
15
17
EXTERN_C DLLEXPORT void WolframLibrary_uninitialize (WolframLibraryData libData) { return ; }
16
18
17
19
EXTERN_C DLLEXPORT int IndexCreate (WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
18
- index_config_t config;
19
20
char * quantization_cstr = nullptr ;
20
21
char * metric_cstr = nullptr ;
21
22
try {
22
23
quantization_cstr = MArgument_getUTF8String (Args[1 ]);
23
24
metric_cstr = MArgument_getUTF8String (Args[0 ]);
24
25
std::size_t dimensions = static_cast <std::size_t >(MArgument_getInteger (Args[2 ]));
25
26
std::size_t capacity = static_cast <std::size_t >(MArgument_getInteger (Args[3 ]));
26
- config.connectivity = static_cast <std::size_t >(MArgument_getInteger (Args[4 ]));
27
- config.expansion_add = static_cast <std::size_t >(MArgument_getInteger (Args[5 ]));
28
- config.expansion_search = static_cast <std::size_t >(MArgument_getInteger (Args[6 ]));
27
+ std::size_t connectivity = static_cast <std::size_t >(MArgument_getInteger (Args[4 ]));
28
+ std::size_t expansion_add = static_cast <std::size_t >(MArgument_getInteger (Args[5 ]));
29
+ std::size_t expansion_search = static_cast <std::size_t >(MArgument_getInteger (Args[6 ]));
30
+
31
+ index_dense_config_t config (connectivity, expansion_add, expansion_search);
29
32
30
33
scalar_kind_t quantization = scalar_kind_from_name (quantization_cstr, std::strlen (quantization_cstr));
31
34
metric_kind_t metric_kind = metric_from_name (metric_cstr, std::strlen (metric_cstr));
32
- index_t index = make_punned<index_t >(metric_kind, dimensions, quantization, config);
35
+ metric_t metric = metric_t::builtin (dimensions, metric_kind, quantization);
36
+ index_t index = index_t::make (metric, config);
37
+
33
38
index .reserve (capacity);
34
39
35
40
index_t * result_ptr = new index_t (std::move (index ));
@@ -115,12 +120,12 @@ EXTERN_C DLLEXPORT int IndexCapacity(WolframLibraryData libData, mint Argc, MArg
115
120
EXTERN_C DLLEXPORT int IndexAdd (WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
116
121
char * path_cstr = nullptr ;
117
122
index_t * c_ptr = (index_t *)MArgument_getUTF8String (Args[0 ]);
118
- float * vector_data = nullptr ;
123
+ double * vector_data = nullptr ;
119
124
try {
120
125
int key = MArgument_getInteger (Args[1 ]);
121
126
MTensor tens = MArgument_getMTensor (Args[2 ]);
122
127
std::size_t len = libData->MTensor_getFlattenedLength (tens);
123
- vector_data = (float *)libData->MTensor_getRealData (tens);
128
+ vector_data = (double *)libData->MTensor_getRealData (tens);
124
129
vector_view_t vector_span = vector_view_t {vector_data, len};
125
130
c_ptr->add (key, vector_span);
126
131
} catch (...) {
@@ -132,24 +137,30 @@ EXTERN_C DLLEXPORT int IndexAdd(WolframLibraryData libData, mint Argc, MArgument
132
137
EXTERN_C DLLEXPORT int IndexSearch (WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
133
138
index_t * c_ptr = (index_t *)MArgument_getUTF8String (Args[0 ]);
134
139
MTensor matches;
135
- mint dims[] = {1 };
136
140
int wanted = MArgument_getInteger (Args[2 ]);
137
- float * vector_data = nullptr ;
138
- int * matches_data = nullptr ;
139
- std::size_t found = 0 ;
141
+ double * vector_data = nullptr ;
142
+ vector_key_t * matches_data = nullptr ;
140
143
141
144
try {
142
- libData->MTensor_new (MType_Integer, 1 , dims, &matches);
143
-
144
145
MTensor tens = MArgument_getMTensor (Args[1 ]);
145
146
std::size_t len = libData->MTensor_getFlattenedLength (tens);
146
- vector_data = (float *)libData->MTensor_getRealData (tens);
147
+ vector_data = (double *)libData->MTensor_getRealData (tens);
147
148
vector_view_t vector_span = vector_view_t {vector_data, len};
149
+ matches_data = (vector_key_t *)std::malloc (sizeof (vector_key_t ) * wanted);
150
+ dense_search_result_t found = c_ptr->search (vector_span, static_cast <std::size_t >(wanted));
151
+
152
+ if (!found) {
153
+ found.error .release ();
154
+ return LIBRARY_FUNCTION_ERROR;
155
+ }
156
+
157
+ std::size_t count = found.dump_to (matches_data);
158
+
159
+ mint dims = static_cast <mint>(count);
160
+ libData->MTensor_new (MType_Integer, 1 , &dims, &matches);
148
161
149
- matches_data = (int *)std::malloc (sizeof (int ) * wanted);
150
- found = c_ptr->search (vector_span, static_cast <std::size_t >(wanted), matches_data, nullptr );
151
- for (mint i = 0 ; i < found; i++)
152
- libData->MTensor_setInteger (matches, &i, matches_data[i]);
162
+ for (mint i = 1 ; i <= (mint)count; i++)
163
+ libData->MTensor_setInteger (matches, &i, static_cast <mint>(matches_data[i - 1 ]));
153
164
154
165
MArgument_setMTensor (Res, matches);
155
166
} catch (...) {
0 commit comments