Skip to content

Commit 00032d0

Browse files
Fix: Wolfram bindings (#437)
Co-authored-by: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
1 parent 032556d commit 00032d0

File tree

2 files changed

+74
-23
lines changed

2 files changed

+74
-23
lines changed

wolfram/CMakeLists.txt

+44-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,52 @@
1+
cmake_minimum_required(VERSION 3.15.0)
12

2-
if(NOT WOLFRAM_PATH)
3+
project(usearch)
4+
5+
if (NOT SYSTEMID)
6+
# set system id and build platform
7+
set(BITNESS 32)
8+
if (CMAKE_SIZEOF_VOID_P EQUAL 8)
9+
set(BITNESS 64)
10+
endif ()
11+
12+
set(SYSTEMID NOTFOUND)
13+
14+
# Determine the current machine's systemid.
15+
if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND BITNESS EQUAL 64)
16+
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
17+
set(SYSTEMID Linux-ARM64)
18+
else ()
19+
set(SYSTEMID Linux-x86-64)
20+
endif ()
21+
elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND BITNESS EQUAL 64)
22+
if (CMAKE_SYSTEM_PROCESSOR MATCHES "arm*")
23+
set(SYSTEMID MacOSX-ARM64)
24+
else ()
25+
set(SYSTEMID MacOSX-x86-64)
26+
endif ()
27+
elseif (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BITNESS EQUAL 64)
28+
if (_MSVC_C_ARCHITECTURE_FAMILY STREQUAL "ARM64")
29+
set(SYSTEMID Windows-ARM64)
30+
else ()
31+
set(SYSTEMID Windows-x86-64)
32+
endif ()
33+
endif ()
34+
35+
if (NOT SYSTEMID)
36+
message(FATAL_ERROR "Unable to determine System ID.")
37+
endif ()
38+
endif ()
39+
40+
if (NOT WOLFRAM_PATH)
341
set(WOLFRAM_PATH "/usr/local/Wolfram/Mathematica")
4-
endif()
42+
endif ()
543

6-
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${WOLFRAM_PATH}/SystemFiles/Links/WSTP/DeveloperKit/Linux-x86-64/CompilerAdditions/")
44+
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
45+
"${WOLFRAM_PATH}/SystemFiles/Links/WSTP/DeveloperKit/${SYSTEMID}/CompilerAdditions/"
46+
)
747
include(WSTP)
848
include_directories("${WOLFRAM_PATH}/SystemFiles/IncludeFiles/C/")
9-
link_directories("${WOLFRAM_PATH}/SystemFiles/Libraries/Linux-x86-64/")
49+
link_directories("${WOLFRAM_PATH}/SystemFiles/Libraries/${SYSTEMID}/")
1050

1151
add_library(usearchWFM SHARED lib.cpp)
1252

wolfram/lib.cpp

+30-19
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,37 @@
44
using namespace unum::usearch;
55

66
using distance_t = distance_punned_t;
7+
using metric_t = metric_punned_t;
78
using index_t = index_dense_t;
8-
using vector_view_t = span_gt<float>;
9+
using vector_view_t = span_gt<double>;
910

1011
using add_result_t = typename index_t::add_result_t;
1112
using search_result_t = typename index_t::search_result_t;
1213
using vector_key_t = typename index_t::vector_key_t;
14+
using dense_search_result_t = typename index_t::search_result_t;
1315

1416
EXTERN_C DLLEXPORT int WolframLibrary_initialize(WolframLibraryData libData) { return LIBRARY_NO_ERROR; }
1517
EXTERN_C DLLEXPORT void WolframLibrary_uninitialize(WolframLibraryData libData) { return; }
1618

1719
EXTERN_C DLLEXPORT int IndexCreate(WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
18-
index_config_t config;
1920
char* quantization_cstr = nullptr;
2021
char* metric_cstr = nullptr;
2122
try {
2223
quantization_cstr = MArgument_getUTF8String(Args[1]);
2324
metric_cstr = MArgument_getUTF8String(Args[0]);
2425
std::size_t dimensions = static_cast<std::size_t>(MArgument_getInteger(Args[2]));
2526
std::size_t capacity = static_cast<std::size_t>(MArgument_getInteger(Args[3]));
26-
config.connectivity = static_cast<std::size_t>(MArgument_getInteger(Args[4]));
27-
config.expansion_add = static_cast<std::size_t>(MArgument_getInteger(Args[5]));
28-
config.expansion_search = static_cast<std::size_t>(MArgument_getInteger(Args[6]));
27+
std::size_t connectivity = static_cast<std::size_t>(MArgument_getInteger(Args[4]));
28+
std::size_t expansion_add = static_cast<std::size_t>(MArgument_getInteger(Args[5]));
29+
std::size_t expansion_search = static_cast<std::size_t>(MArgument_getInteger(Args[6]));
30+
31+
index_dense_config_t config(connectivity, expansion_add, expansion_search);
2932

3033
scalar_kind_t quantization = scalar_kind_from_name(quantization_cstr, std::strlen(quantization_cstr));
3134
metric_kind_t metric_kind = metric_from_name(metric_cstr, std::strlen(metric_cstr));
32-
index_t index = make_punned<index_t>(metric_kind, dimensions, quantization, config);
35+
metric_t metric = metric_t::builtin(dimensions, metric_kind, quantization);
36+
index_t index = index_t::make(metric, config);
37+
3338
index.reserve(capacity);
3439

3540
index_t* result_ptr = new index_t(std::move(index));
@@ -115,12 +120,12 @@ EXTERN_C DLLEXPORT int IndexCapacity(WolframLibraryData libData, mint Argc, MArg
115120
EXTERN_C DLLEXPORT int IndexAdd(WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
116121
char* path_cstr = nullptr;
117122
index_t* c_ptr = (index_t*)MArgument_getUTF8String(Args[0]);
118-
float* vector_data = nullptr;
123+
double* vector_data = nullptr;
119124
try {
120125
int key = MArgument_getInteger(Args[1]);
121126
MTensor tens = MArgument_getMTensor(Args[2]);
122127
std::size_t len = libData->MTensor_getFlattenedLength(tens);
123-
vector_data = (float*)libData->MTensor_getRealData(tens);
128+
vector_data = (double*)libData->MTensor_getRealData(tens);
124129
vector_view_t vector_span = vector_view_t{vector_data, len};
125130
c_ptr->add(key, vector_span);
126131
} catch (...) {
@@ -132,24 +137,30 @@ EXTERN_C DLLEXPORT int IndexAdd(WolframLibraryData libData, mint Argc, MArgument
132137
EXTERN_C DLLEXPORT int IndexSearch(WolframLibraryData libData, mint Argc, MArgument* Args, MArgument Res) {
133138
index_t* c_ptr = (index_t*)MArgument_getUTF8String(Args[0]);
134139
MTensor matches;
135-
mint dims[] = {1};
136140
int wanted = MArgument_getInteger(Args[2]);
137-
float* vector_data = nullptr;
138-
int* matches_data = nullptr;
139-
std::size_t found = 0;
141+
double* vector_data = nullptr;
142+
vector_key_t* matches_data = nullptr;
140143

141144
try {
142-
libData->MTensor_new(MType_Integer, 1, dims, &matches);
143-
144145
MTensor tens = MArgument_getMTensor(Args[1]);
145146
std::size_t len = libData->MTensor_getFlattenedLength(tens);
146-
vector_data = (float*)libData->MTensor_getRealData(tens);
147+
vector_data = (double*)libData->MTensor_getRealData(tens);
147148
vector_view_t vector_span = vector_view_t{vector_data, len};
149+
matches_data = (vector_key_t*)std::malloc(sizeof(vector_key_t) * wanted);
150+
dense_search_result_t found = c_ptr->search(vector_span, static_cast<std::size_t>(wanted));
151+
152+
if (!found) {
153+
found.error.release();
154+
return LIBRARY_FUNCTION_ERROR;
155+
}
156+
157+
std::size_t count = found.dump_to(matches_data);
158+
159+
mint dims = static_cast<mint>(count);
160+
libData->MTensor_new(MType_Integer, 1, &dims, &matches);
148161

149-
matches_data = (int*)std::malloc(sizeof(int) * wanted);
150-
found = c_ptr->search(vector_span, static_cast<std::size_t>(wanted), matches_data, nullptr);
151-
for (mint i = 0; i < found; i++)
152-
libData->MTensor_setInteger(matches, &i, matches_data[i]);
162+
for (mint i = 1; i <= (mint)count; i++)
163+
libData->MTensor_setInteger(matches, &i, static_cast<mint>(matches_data[i - 1]));
153164

154165
MArgument_setMTensor(Res, matches);
155166
} catch (...) {

0 commit comments

Comments
 (0)