Skip to content

Commit 4d92cd3

Browse files
committed
chore: Update pybind11 submodule to commit 3e9dfa2
1 parent 4c42f1a commit 4d92cd3

File tree

5 files changed

+52
-15
lines changed

5 files changed

+52
-15
lines changed

cpp/prtree.h

+16-12
Original file line numberDiff line numberDiff line change
@@ -51,36 +51,38 @@ namespace py = pybind11;
5151
template <class T>
5252
using vec = std::vector<T>;
5353

54-
template <typename Sequence >
55-
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence& seq) {
54+
template <typename Sequence>
55+
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence &seq)
56+
{
5657

5758
auto size = seq.size();
5859
auto data = seq.data();
5960
std::unique_ptr<Sequence> seq_ptr = std::make_unique<Sequence>(std::move(seq));
60-
auto capsule = py::capsule(seq_ptr.get(), [](void *p) { std::unique_ptr<Sequence>(reinterpret_cast<Sequence*>(p)); });
61+
auto capsule = py::capsule(seq_ptr.get(), [](void *p)
62+
{ std::unique_ptr<Sequence>(reinterpret_cast<Sequence *>(p)); });
6163
seq_ptr.release();
6264
return py::array(size, data, capsule);
6365
}
6466

6567
template <typename T>
66-
auto list_list_to_arrays(vec<vec<T>> out_ll){
68+
auto list_list_to_arrays(vec<vec<T>> out_ll)
69+
{
6770
vec<T> out_s;
6871
out_s.reserve(out_ll.size());
6972
std::size_t sum = 0;
70-
for (auto &&i : out_ll) {
73+
for (auto &&i : out_ll)
74+
{
7175
out_s.push_back(i.size());
7276
sum += i.size();
7377
}
7478
vec<T> out;
7579
out.reserve(sum);
76-
for(const auto &v: out_ll)
80+
for (const auto &v : out_ll)
7781
out.insert(out.end(), v.begin(), v.end());
7882

7983
return make_tuple(
80-
std::move(as_pyarray(out_s))
81-
,
82-
std::move(as_pyarray(out))
83-
);
84+
std::move(as_pyarray(out_s)),
85+
std::move(as_pyarray(out)));
8486
}
8587

8688
template <class T, size_t StaticCapacity>
@@ -242,7 +244,7 @@ class BB
242244
}
243245
for (int i = 0; i < D; ++i)
244246
{
245-
flags[i] = -minima[i] < maxima[i];
247+
flags[i] = -minima[i] <= maxima[i];
246248
}
247249
for (int i = 0; i < D; ++i)
248250
{
@@ -1291,7 +1293,8 @@ class PRTree
12911293
return out;
12921294
}
12931295

1294-
auto find_all_array(const py::array_t<float> &x){
1296+
auto find_all_array(const py::array_t<float> &x)
1297+
{
12951298
return list_list_to_arrays(std::move(find_all(x)));
12961299
}
12971300

@@ -1334,6 +1337,7 @@ class PRTree
13341337
};
13351338

13361339
bfs<T, B, D>(std::move(find_func), flat_tree, target);
1340+
std::sort(out.begin(), out.end());
13371341
return out;
13381342
}
13391343

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
numpy>=1.16
1+
numpy>=1.16,<2.0
22
pybind11; platform_machine != "x86_64" and platform_machine != "amd64" and platform_machine != "AMD64" and sys_platform == 'darwin' # for m1 mac
33
cmake; platform_machine != "x86_64" and platform_machine != "amd64" and platform_machine != "AMD64" and sys_platform == 'darwin' # for m1 mac

setup.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from setuptools import Extension, find_packages, setup
1010
from setuptools.command.build_ext import build_ext
1111

12-
version = "v0.6.0"
12+
version = "v0.6.1"
1313

1414
sys.path.append("./tests")
1515

@@ -109,6 +109,10 @@ def build_extension(self, ext):
109109
classifiers=[
110110
"License :: OSI Approved :: MIT License",
111111
"Programming Language :: Python :: 3",
112+
"Programming Language :: Python :: 3.8",
112113
"Programming Language :: Python :: 3.9",
114+
"Programming Language :: Python :: 3.10",
115+
"Programming Language :: Python :: 3.11",
116+
"Programming Language :: Python :: 3.12",
113117
],
114118
)

tests/test_PRTree.py

+29
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,32 @@ def test_obj(seed, PRTree, dim, tmp_path):
139139
idx = prtree.query(q)
140140
return_obj = prtree2.query(q, return_obj=True)
141141
assert set(return_obj) == set([obj[i] for i in idx])
142+
143+
144+
def test_readme():
145+
idxes = np.array([1, 2])
146+
rects = np.array([[0.0, 0.0, 1.0, 0.5], [1.0, 1.5, 1.2, 3.0]])
147+
prtree = PRTree2D(idxes, rects)
148+
149+
# batch query
150+
q = np.array([[0.5, 0.2, 0.6, 0.3], [0.8, 0.5, 1.5, 3.5]])
151+
result = prtree.batch_query(q)
152+
assert result == [[1], [1, 2]]
153+
154+
# Insert
155+
prtree.insert(3, [1.0, 1.0, 2.0, 2.0])
156+
q = np.array([[0.5, 0.2, 0.6, 0.3], [0.8, 0.5, 1.5, 3.5]])
157+
result = prtree.batch_query(q)
158+
assert result == [[1], [1, 2, 3]]
159+
160+
# Erase
161+
prtree.erase(2)
162+
result = prtree.batch_query(q)
163+
assert result == [[1], [1, 3]]
164+
165+
# non-batch query
166+
assert prtree.query([0.5, 0.5, 1.0, 1.0]) == [1, 3]
167+
168+
# point query
169+
assert prtree.query([0.5, 0.5]) == [1]
170+
assert prtree.query(0.5, 0.5) == [1]

third/pybind11

Submodule pybind11 updated 212 files

0 commit comments

Comments
 (0)