Skip to content

Commit 8eecdb6

Browse files
MB-63643: Fix missing num_threads clauses (#44)
- In some places the `#pragma omp` directives were missing the `num_threads` clause, so the global OMP config was ignored at those call sites.
---------
Co-authored-by: Abhinav Dangeti <abhinav@couchbase.com>
1 parent 224acef commit 8eecdb6

10 files changed

+25
-25
lines changed

faiss/IndexFastScan.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -323,7 +323,7 @@ void IndexFastScan::search_dispatch_implem(
323323
}
324324
} else {
325325
// explicitly slice over threads
326-
#pragma omp parallel for num_threads(nt)
326+
#pragma omp parallel for num_threads(num_omp_threads)
327327
for (int slice = 0; slice < nt; slice++) {
328328
idx_t i0 = n * slice / nt;
329329
idx_t i1 = n * (slice + 1) / nt;

faiss/IndexIVF.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -1006,7 +1006,7 @@ void IndexIVF::search_and_reconstruct(
10061006
labels,
10071007
true /* store_pairs */,
10081008
params);
1009-
#pragma omp parallel for if (n * k > 1000)
1009+
#pragma omp parallel for if (n * k > 1000) num_threads(num_omp_threads)
10101010
for (idx_t ij = 0; ij < n * k; ij++) {
10111011
idx_t key = labels[ij];
10121012
float* reconstructed = recons + ij * d;
@@ -1068,7 +1068,7 @@ void IndexIVF::search_and_return_codes(
10681068
code_size_1 += coarse_code_size();
10691069
}
10701070

1071-
#pragma omp parallel for if (n * k > 1000)
1071+
#pragma omp parallel for if (n * k > 1000) num_threads(num_omp_threads)
10721072
for (idx_t ij = 0; ij < n * k; ij++) {
10731073
idx_t key = labels[ij];
10741074
uint8_t* code1 = codes + ij * code_size_1;

faiss/IndexIVFFastScan.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -640,7 +640,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
640640
} else {
641641
// explicitly slice over threads
642642
int nslice = compute_search_nslice(this, n, cq.nprobe);
643-
#pragma omp parallel
643+
#pragma omp parallel num_threads(num_omp_threads)
644644
{
645645
RangeSearchPartialResult pres(&rres);
646646

faiss/impl/PolysemousTraining.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -779,7 +779,7 @@ void PolysemousTraining::optimize_reproduce_distances(
779779
nt);
780780
}
781781

782-
#pragma omp parallel for num_threads(nt)
782+
#pragma omp parallel for num_threads(num_omp_threads)
783783
for (int m = 0; m < pq.M; m++) {
784784
std::vector<double> dis_table;
785785

faiss/impl/ProductQuantizer.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
313313
}
314314

315315
void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
316-
#pragma omp parallel for if (n > 100)
316+
#pragma omp parallel for if (n > 100) num_threads(num_omp_threads)
317317
for (int64_t i = 0; i < n; i++) {
318318
this->decode(code + code_size * i, x + d * i);
319319
}

faiss/impl/residual_quantizer_encode_steps.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,7 @@ void beam_search_encode_step(
275275
}
276276
InterruptCallback::check();
277277

278-
#pragma omp parallel for if (n > 100)
278+
#pragma omp parallel for if (n > 100) num_threads(num_omp_threads)
279279
for (int64_t i = 0; i < n; i++) {
280280
const int32_t* codes_i = codes + i * m * beam_size;
281281
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
@@ -399,7 +399,7 @@ void beam_search_encode_step_tab(
399399
{
400400
FAISS_THROW_IF_NOT(ldc >= K);
401401

402-
#pragma omp parallel for if (n > 100) schedule(dynamic)
402+
#pragma omp parallel for if (n > 100) schedule(dynamic) num_threads(num_omp_threads)
403403
for (int64_t i = 0; i < n; i++) {
404404
std::vector<float> cent_distances(beam_size * K);
405405
std::vector<float> cd_common(K);

faiss/utils/distances.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ void exhaustive_inner_product_seq(
146146

147147
FAISS_ASSERT(use_sel == (sel != nullptr));
148148

149-
#pragma omp parallel num_threads(nt)
149+
#pragma omp parallel num_threads(num_omp_threads)
150150
{
151151
SingleResultHandler resi(res);
152152
#pragma omp for
@@ -183,7 +183,7 @@ void exhaustive_L2sqr_seq(
183183

184184
FAISS_ASSERT(use_sel == (sel != nullptr));
185185

186-
#pragma omp parallel num_threads(nt)
186+
#pragma omp parallel num_threads(num_omp_threads)
187187
{
188188
SingleResultHandler resi(res);
189189
#pragma omp for

faiss/utils/hamming.cpp

+5-5
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,7 @@ void hamming_range_search(
293293
int radius,
294294
size_t code_size,
295295
RangeSearchResult* res) {
296-
#pragma omp parallel
296+
#pragma omp parallel num_threads(num_omp_threads)
297297
{
298298
RangeSearchPartialResult pres(res);
299299

@@ -687,7 +687,7 @@ void pack_bitstrings(
687687
uint8_t* packed,
688688
size_t code_size) {
689689
FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8);
690-
#pragma omp parallel for if (n > 1000)
690+
#pragma omp parallel for if (n > 1000) num_threads(num_omp_threads)
691691
for (int64_t i = 0; i < n; i++) {
692692
const int32_t* in = unpacked + i * M;
693693
uint8_t* out = packed + i * code_size;
@@ -710,7 +710,7 @@ void pack_bitstrings(
710710
totbit += nbit[j];
711711
}
712712
FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8);
713-
#pragma omp parallel for if (n > 1000)
713+
#pragma omp parallel for if (n > 1000) num_threads(num_omp_threads)
714714
for (int64_t i = 0; i < n; i++) {
715715
const int32_t* in = unpacked + i * M;
716716
uint8_t* out = packed + i * code_size;
@@ -729,7 +729,7 @@ void unpack_bitstrings(
729729
size_t code_size,
730730
int32_t* unpacked) {
731731
FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8);
732-
#pragma omp parallel for if (n > 1000)
732+
#pragma omp parallel for if (n > 1000) num_threads(num_omp_threads)
733733
for (int64_t i = 0; i < n; i++) {
734734
const uint8_t* in = packed + i * code_size;
735735
int32_t* out = unpacked + i * M;
@@ -752,7 +752,7 @@ void unpack_bitstrings(
752752
totbit += nbit[j];
753753
}
754754
FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8);
755-
#pragma omp parallel for if (n > 1000)
755+
#pragma omp parallel for if (n > 1000) num_threads(num_omp_threads)
756756
for (int64_t i = 0; i < n; i++) {
757757
const uint8_t* in = packed + i * code_size;
758758
int32_t* out = unpacked + i * M;

faiss/utils/sorting.cpp

+9-9
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ void parallel_merge(
6161
s2s[nt - 1].i1 = s2.i1;
6262

6363
// not sure parallel actually helps here
64-
#pragma omp parallel for num_threads(nt)
64+
#pragma omp parallel for num_threads(num_omp_threads)
6565
for (int t = 0; t < nt; t++) {
6666
s1s[t].i0 = s1.i0 + s1.len() * t / nt;
6767
s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;
@@ -93,7 +93,7 @@ void parallel_merge(
9393
assert(sws[nt - 1].i1 == s1.i1);
9494

9595
// do the actual merging
96-
#pragma omp parallel for num_threads(nt)
96+
#pragma omp parallel for num_threads(num_omp_threads)
9797
for (int t = 0; t < nt; t++) {
9898
SegmentS sw = sws[t];
9999
SegmentS s1t = s1s[t];
@@ -176,7 +176,7 @@ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
176176
int sub_nt = nseg % 2 == 0 ? nt : nt - 1;
177177
int sub_nseg1 = nseg / 2;
178178

179-
#pragma omp parallel for num_threads(nseg1)
179+
#pragma omp parallel for num_threads(num_omp_threads)
180180
for (int s = 0; s < nseg; s += 2) {
181181
if (s + 1 == nseg) { // otherwise isolated segment
182182
memcpy(permB + segs[s].i0,
@@ -257,7 +257,7 @@ void bucket_sort_parallel(
257257
int64_t* perm,
258258
int nt_in) {
259259
memset(lims, 0, sizeof(*lims) * (vmax + 1));
260-
#pragma omp parallel num_threads(nt_in)
260+
#pragma omp parallel num_threads(num_omp_threads)
261261
{
262262
int nt = omp_get_num_threads(); // might be different from nt_in
263263
int rank = omp_get_thread_num();
@@ -483,7 +483,7 @@ void bucket_sort_inplace_parallel(
483483
nbucket); // DON'T use std::vector<bool> that cannot be accessed
484484
// safely from multiple threads!!!
485485

486-
#pragma omp parallel num_threads(nt_in)
486+
#pragma omp parallel num_threads(num_omp_threads)
487487
{
488488
int nt = omp_get_num_threads(); // might be different from nt_in (?)
489489
int rank = omp_get_thread_num();
@@ -709,7 +709,7 @@ inline int64_t hash_function(int64_t x) {
709709

710710
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
711711
size_t capacity = (size_t)1 << log2_capacity;
712-
#pragma omp parallel for
712+
#pragma omp parallel for num_threads(num_omp_threads)
713713
for (int64_t i = 0; i < capacity; i++) {
714714
tab[2 * i] = -1;
715715
tab[2 * i + 1] = -1;
@@ -729,7 +729,7 @@ void hashtable_int64_to_int64_add(
729729
int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
730730
size_t nbucket = (size_t)1 << log2_nbucket;
731731

732-
#pragma omp parallel for
732+
#pragma omp parallel for num_threads(num_omp_threads)
733733
for (int64_t i = 0; i < n; i++) {
734734
hk[i] = hash_function(keys[i]) & mask;
735735
bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
@@ -746,7 +746,7 @@ void hashtable_int64_to_int64_add(
746746
omp_get_max_threads());
747747

748748
int num_errors = 0;
749-
#pragma omp parallel for reduction(+ : num_errors)
749+
#pragma omp parallel for reduction(+ : num_errors) num_threads(num_omp_threads)
750750
for (int64_t bucket = 0; bucket < nbucket; bucket++) {
751751
size_t k0 = bucket << (log2_capacity - log2_nbucket);
752752
size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
@@ -793,7 +793,7 @@ void hashtable_int64_to_int64_lookup(
793793
int64_t mask = capacity - 1;
794794
int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
795795

796-
#pragma omp parallel for
796+
#pragma omp parallel for num_threads(num_omp_threads)
797797
for (int64_t i = 0; i < n; i++) {
798798
int64_t k = keys[i];
799799
int64_t hk = hash_function(k) & mask;

faiss/utils/utils.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs) {
455455
// so below codes only accept n <= std::numeric_limits<ssize_t>::max()
456456
using ssize_t = std::make_signed<std::size_t>::type;
457457
const ssize_t size = n;
458-
#pragma omp parallel for if (size > 1000)
458+
#pragma omp parallel for if (size > 1000) num_threads(num_omp_threads)
459459
for (ssize_t i_ = 0; i_ < size; i_++) {
460460
const auto i = static_cast<std::size_t>(i_);
461461
cs[i] = bvec_checksum(d, a + i * d);

0 commit comments

Comments
 (0)