Skip to content

Commit 789d325

Browse files
andrzejnovakhenryiiipre-commit-ci[bot]
authored
feat: allow rebinning by passing edges or a new axis (#977)
* feat: add rebin by edges/axis, fix flow bins * chore: linting * fix: handle flow bins on rebin, add tests * refactor: move rebin logic out to tag.py * chore: pass mypy * refactor: cleanup rebin arg names * Apply suggestions from code review * Update tag.py * style: pre-commit fixes --------- Co-authored-by: Henry Schreiner <HenrySchreinerIII@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a6876bb commit 789d325

File tree

3 files changed

+95
-11
lines changed

3 files changed

+95
-11
lines changed

src/boost_histogram/histogram.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
881881
reduced: CppHistogram | None = None
882882

883883
# Compute needed slices and projections
884-
for i, ind in enumerate(indexes):
884+
for i, ind in enumerate(indexes): # pylint: disable=too-many-nested-blocks
885885
if isinstance(ind, SupportsIndex):
886886
pick_each[i] = ind.__index__() + (
887887
1 if self.axes[i].traits.underflow else 0
@@ -967,14 +967,26 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
967967

968968
new_reduced = reduced.__class__(axes)
969969
new_view = new_reduced.view(flow=True)
970-
971-
j = 1
970+
j = 0
971+
new_j_base = 0
972+
if self.axes[i].traits.underflow:
973+
groups.insert(0, 1)
974+
else:
975+
new_j_base = 1
976+
if self.axes[i].traits.overflow:
977+
groups.append(1)
972978
for new_j, group in enumerate(groups):
973979
for _ in range(group):
974980
pos = [slice(None)] * (i)
975-
new_view[(*pos, new_j + 1, ...)] += _to_view(
976-
reduced_view[(*pos, j, ...)]
977-
)
981+
if new_view.dtype.names:
982+
for field in new_view.dtype.names:
983+
new_view[(*pos, new_j + new_j_base, ...)][
984+
field
985+
] += reduced_view[(*pos, j, ...)][field]
986+
else:
987+
new_view[(*pos, new_j + new_j_base, ...)] += (
988+
reduced_view[(*pos, j, ...)]
989+
)
978990
j += 1
979991

980992
reduced = new_reduced

src/boost_histogram/tag.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from builtins import sum
77
from typing import TYPE_CHECKING, Sequence, TypeVar
88

9+
import numpy as np
10+
911
if TYPE_CHECKING:
1012
from uhi.typing.plottable import PlottableAxis
1113

@@ -112,26 +114,43 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002
112114

113115
class rebin:
114116
__slots__ = (
117+
"axis",
118+
"edges",
115119
"factor",
116120
"groups",
117121
)
118122

119123
def __init__(
120124
self,
121-
factor: int | None = None,
125+
factor_or_axis: int | PlottableAxis | None = None,
126+
/,
122127
*,
128+
factor: int | None = None,
123129
groups: Sequence[int] | None = None,
130+
edges: Sequence[int | float] | None = None,
131+
axis: PlottableAxis | None = None,
124132
) -> None:
125-
if not sum(i is None for i in [factor, groups]) == 1:
126-
raise ValueError("Exactly one, a factor or groups should be provided")
127-
self.factor = factor
133+
if (
134+
sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis])
135+
!= 1
136+
):
137+
raise ValueError("Exactly one argument should be provided")
128138
self.groups = groups
139+
self.edges = edges
140+
self.axis = axis
141+
self.factor = factor
142+
if isinstance(factor_or_axis, int):
143+
self.factor = factor_or_axis
144+
elif factor_or_axis is not None:
145+
self.axis = factor_or_axis
129146

130147
def __repr__(self) -> str:
131148
repr_str = f"{self.__class__.__name__}"
132-
args: dict[str, int | Sequence[int] | None] = {
149+
args: dict[str, int | Sequence[int | float] | PlottableAxis | None] = {
133150
"factor": self.factor,
134151
"groups": self.groups,
152+
"edges": self.edges,
153+
"axis": self.axis,
135154
}
136155
for k, v in args.items():
137156
if v is not None:
@@ -147,4 +166,30 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
147166
return self.groups
148167
if self.factor is not None:
149168
return [self.factor] * len(axis)
169+
if self.edges is not None or self.axis is not None:
170+
newedges = None
171+
if self.axis is not None and hasattr(self.axis, "edges"):
172+
newedges = self.axis.edges
173+
elif self.edges is not None:
174+
newedges = self.edges
175+
176+
if newedges is not None and hasattr(axis, "edges"):
177+
assert newedges[0] == axis.edges[0], "Edges must start at first bin"
178+
assert newedges[-1] == axis.edges[-1], "Edges must end at last bin"
179+
assert all(
180+
np.isclose(
181+
axis.edges[np.abs(axis.edges - edge).argmin()],
182+
edge,
183+
)
184+
for edge in newedges
185+
), "Edges must be in the axis"
186+
matched_ixes = np.where(
187+
np.isin(
188+
axis.edges,
189+
newedges,
190+
)
191+
)[0]
192+
return [
193+
int(ix - matched_ixes[i]) for i, ix in enumerate(matched_ixes[1:])
194+
]
150195
raise ValueError("No rebinning factor or groups provided")

tests/test_histogram.py

+27
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,33 @@ def test_rebin_1d(metadata):
650650
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
651651
assert h.axes[0].metadata is hs.axes[0].metadata
652652

653+
hs = h[bh.rebin(edges=[1.0, 1.2, 1.6, 2.2, 5.0])]
654+
assert_array_equal(hs.view(), [1, 0, 0, 3])
655+
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
656+
657+
hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0]))]
658+
assert_array_equal(hs.view(), [1, 0, 0, 3])
659+
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
660+
661+
662+
def test_rebin_1d_flow():
663+
h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=True))
664+
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
665+
hs = h[bh.rebin(edges=[0, 3, 5.0])]
666+
assert_array_equal(hs.view(), [2, 2])
667+
assert_array_equal(hs.view(flow=True), [1, 2, 2, 1])
668+
assert_array_equal(hs.axes.edges[0], [0.0, 3.0, 5.0])
669+
670+
h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=False, overflow=False))
671+
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
672+
hs = h[bh.rebin(edges=[0, 3, 5.0])]
673+
assert_array_equal(hs.view(flow=True), [0, 2, 2, 0])
674+
675+
h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=False))
676+
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
677+
hs = h[bh.rebin(edges=[0, 3, 5.0])]
678+
assert_array_equal(hs.view(flow=True), [1, 2, 2, 0])
679+
653680

654681
def test_shrink_rebin_1d():
655682
h = bh.Histogram(bh.axis.Regular(20, 0, 4))

0 commit comments

Comments
 (0)