Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RaBitQ implementation #4235

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

alexanderguzhva
Copy link
Contributor

This is a reference implementation of the https://arxiv.org/pdf/2405.12497

Jianyang Gao, Cheng Long, "RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search".

The goal is to correctly set up the internals using Faiss.

The following comments for the implementation:

  • The code does not include the computations for the symmetric distance, because it is absent in the original article. This can be added later, though.
  • The original RaBitQ includes random matrix rotation as a part of it, but I've decided to rely on external faiss::IndexPreTransform and faiss::RandomRotationMatrix facilities.
  • Certain features required internal changes in faiss::IndexIVF, but I did that as least invasive as possible, without breaking the backward compatibility.
  • Not sure about naming convensions, maybe certain classes and structures need to be renamed
  • METRIC_INNER_PRODUCT is supported as well
  • More unit tests are needed?
  • I did not bring any hardware-specific optimizations, bcz this is a reference implementation. Certain simdlib facilities may be added later, if needed

Here's how to use IndexRaBitQ

        ds = datasets.SyntheticDataset(...)

        index_rbq = faiss.IndexRaBitQ(ds.d, faiss.METRIC_L2)
        index_rbq.qb = 8

        # wrap with random rotations
        rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
        rrot.init(rrot_seed)

        index_cand = faiss.IndexPreTransform(rrot, index_rbq)
        index_cand.train(ds.get_train())
        index_cand.add(ds.get_database())

Here's how to use IndexIVFRaBitQ

        ds = datasets.SyntheticDataset(...)

        index_flat = faiss.IndexFlat(ds.d, faiss.METRIC_L2)
        index_rbq = faiss.IndexIVFRaBitQ(index_flat, ds.d, nlist, faiss.METRIC_L2)
        index_rbq.qb = 8

        # wrap with random rotations
        rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
        rrot.init(rrot_seed)

        index_cand = faiss.IndexPreTransform(rrot, index_rbq)
        index_cand.train(ds.get_train())
        index_cand.add(ds.get_database())

Copy link
Contributor

@mdouze mdouze left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thanks for deferring the SIMD optimization to later.
I left a few comments.

faiss/IndexIVF.h Outdated
*/
virtual InvertedListScanner* get_InvertedListScanner(
bool store_pairs = false,
const IDSelector* sel = nullptr) const;

/** Get a scanner for this index (store_pairs means ignore labels).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to replace get_InvertedListScanner altogether with the version that takes IVFSearchParameters (and no sel, since IDSelector is a field of IVFSearchParameters)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about the backward compatibility? This is why I introduced get_InvertedListScanner_2. I mean that I can upgrade the method signature Faiss-wide, but what about the external code?

FAISS_ASSERT(codes != nullptr);
FAISS_ASSERT(x != nullptr);

if (n == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this test before the asserts


struct FactorsData {
// ||or - c||
float factor_0 = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not give the fields proper names? Otherwise you might as well use float factors[4]

const uint8_t* query_j = rearranged_rotated_qq.data() + j * di_8b;

// process 64-bit popcounts
unsigned long long count = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use explicity sized integer types (eg uint64_t)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do. As far as I remember, I've used unsigned long long, bcz there were problems with compilations on MacOS for these __builtin_popcount functions

float factor_2 = 0;
// ||or||^2
float factor_3 = 0;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my implementation there are only 2 floats per database vector. Do you store additional ones for efficiency?

@alexanderguzhva
Copy link
Contributor Author

@mdouze two more comments after the discussion

First. IndexIVF::get_InvertedListScanner() signature is made into

    virtual InvertedListScanner* get_InvertedListScanner(
            bool store_pairs = false,
            const IDSelector* sel = nullptr,
            const IVFSearchParameters* params = nullptr) const;

because of the following logic that can override sel

void IndexIVF::search_preassigned(...) const {
    ...
    IDSelector* sel = params ? params->sel : nullptr;
    const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
    if (selr) {
        if (selr->assume_sorted) {
            sel = nullptr; // use special IDSelectorRange processing
        } else {
            selr = nullptr; // use generic processing
        }
    }
    ....
}

Please let me know your thoughts.

Second. RaBitQ uses 3 factors

struct FactorsData {
    float or_minus_c_l2sqr = 0;
    float dp_multiplier = 0;
    // this is needed to support BOTH L2 and IP on the same data
    float or_l2sqr = 0;
};

The third one or_l2sqr is needed to support both L2 and IP, similar to how PQ / SQ can use the same data for different metrics. So, these three numbers per vector are independent from a chosen metric. These three factors can be reduced to two, if we make a decision to make factors dependent from the chosen metric.
Please let me know if you'd like to have 2 or 3 factors per vector. Just double-checking.

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants