Skip to content

Commit

Permalink
refactor: move IterativeKernelHerding to benchmark_util.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gw265981 committed Feb 24, 2025
1 parent cd14075 commit 538071a
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 101 deletions.
2 changes: 1 addition & 1 deletion benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sklearn.datasets import make_blobs

from coreax import Data, SlicedScoreMatching
from coreax.benchmark_util import IterativeKernelHerding
from coreax.kernels import (
SquaredExponentialKernel,
SteinKernel,
Expand All @@ -46,7 +47,6 @@
from coreax.metrics import KSD, MMD
from coreax.solvers import (
CompressPlusPlus,
IterativeKernelHerding,
KernelHerding,
KernelThinning,
RandomSample,
Expand Down
46 changes: 43 additions & 3 deletions coreax/benchmark_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@
"""

from collections.abc import Callable
from typing import Optional, Union
from typing import Optional, TypeVar, Union

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float

from coreax import Data
from coreax import Coresubset, Data, SupervisedData
from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import (
CompressPlusPlus,
IterativeKernelHerding,
HerdingState,
KernelHerding,
KernelThinning,
MapReduce,
Expand All @@ -43,6 +43,46 @@
)
from coreax.util import KeyArrayLike

_Data = TypeVar("_Data", Data, SupervisedData)


class IterativeKernelHerding(KernelHerding[_Data]): # pylint: disable=too-many-ancestors
r"""
Iterative Kernel Herding - perform multiple refinements of Kernel Herding.
Wrapper around :meth:`~coreax.solvers.KernelHerding.reduce_iterative` for
benchmarking purposes.
:param num_iterations: Number of refinement iterations
:param t_schedule: An :class:`Array` of length `num_iterations`, where
`t_schedule[i]` is the temperature parameter used for iteration i. If None,
standard Kernel Herding is used
"""

num_iterations: int = 1
t_schedule: Optional[Array] = None

def reduce(
self,
dataset: _Data,
solver_state: Optional[HerdingState] = None,
) -> tuple[Coresubset[_Data], HerdingState]:
"""
Perform Kernel Herding reduction followed by additional refinement iterations.
:param dataset: The dataset to process.
:param solver_state: Optional solver state.
:return: Refined coresubset and final solver state.
"""
coreset, reduced_solver_state = self.reduce_iterative(
dataset,
solver_state,
num_iterations=self.num_iterations,
t_schedule=self.t_schedule,
)

return coreset, reduced_solver_state


def calculate_delta(n: int) -> Float[Array, "1"]:
r"""
Expand Down
2 changes: 0 additions & 2 deletions coreax/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
GreedyKernelPoints,
GreedyKernelPointsState,
HerdingState,
IterativeKernelHerding,
KernelHerding,
KernelThinning,
RandomSample,
Expand Down Expand Up @@ -62,5 +61,4 @@
"CaratheodoryRecombination",
"TreeRecombination",
"CompressPlusPlus",
"IterativeKernelHerding",
]
74 changes: 0 additions & 74 deletions coreax/solvers/coresubset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,77 +1490,3 @@ def _compress_plus_plus(indices: Array) -> Array:

plus_plus_indices = _compress_plus_plus(clipped_indices)
return Coresubset(Data(plus_plus_indices), dataset), None


class IterativeKernelHerding(ExplicitSizeSolver):
r"""
Iterative Kernel Herding - perform multiple refinements of Kernel Herding.
Reduce using :class:`~coreax.solvers.KernelHerding` then refine set number of times.
All the parameters except the `num_iterations` are passed to
:class:`~coreax.solvers.KernelHerding`.
:param coreset_size: The desired size of the solved coreset.
:param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
kernel function.
:math:`k: \\mathbb{R}^d \times \\mathbb{R}^d \rightarrow \\mathbb{R}`
:param unique: Boolean that ensures the resulting coresubset will only contain
unique elements.
:param block_size: Block size passed to
:meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
:param unroll: Unroll parameter passed to
:meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
:param probabilistic: If :data:`True`, the elements are chosen probabilistically at
each iteration. Otherwise, standard Kernel Herding is run.
:param temperature: Temperature parameter, which controls how uniform the
probabilities are for probabilistic selection.
:param random_key: Key for random number generation, only used if probabilistic
:param num_iterations: Number of refinement iterations.
"""

num_iterations: int
kernel: ScalarValuedKernel
unique: bool = True
block_size: Optional[Union[int, tuple[Optional[int], Optional[int]]]] = None
unroll: Union[int, bool, tuple[Union[int, bool], Union[int, bool]]] = 1
probabilistic: bool = False
temperature: Union[float, Scalar] = eqx.field(default=1.0)
random_key: KeyArrayLike = eqx.field(default_factory=lambda: jax.random.key(0))

def reduce(
self,
dataset: _Data,
solver_state: Optional[HerdingState] = None,
) -> tuple[Coresubset[_Data], HerdingState]:
"""
Perform Kernel Herding reduction followed by additional refinement iterations.
:param dataset: The dataset to process.
:param solver_state: Optional solver state.
:return: Refined coresubset and final solver state.
"""
herding_solver = KernelHerding(
coreset_size=self.coreset_size,
kernel=self.kernel,
unique=self.unique,
block_size=self.block_size,
unroll=self.unroll,
probabilistic=self.probabilistic,
temperature=self.temperature,
random_key=self.random_key,
)

coreset, reduced_solver_state = herding_solver.reduce(dataset, solver_state)

def refine_step(_, state):
coreset, reduced_solver_state = state
coreset, reduced_solver_state = herding_solver.refine(
coreset, reduced_solver_state
)
return (coreset, reduced_solver_state)

coreset, reduced_solver_state = lax.fori_loop(
0, self.num_iterations, refine_step, (coreset, reduced_solver_state)
)

return coreset, reduced_solver_state
7 changes: 5 additions & 2 deletions tests/unit/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
train_and_evaluate,
)
from coreax import Data
from coreax.benchmark_util import calculate_delta, initialise_solvers
from coreax.benchmark_util import (
IterativeKernelHerding,
calculate_delta,
initialise_solvers,
)
from coreax.solvers import (
CompressPlusPlus,
IterativeKernelHerding,
KernelHerding,
KernelThinning,
MapReduce,
Expand Down
19 changes: 0 additions & 19 deletions tests/unit/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
GreedyKernelPoints,
GreedyKernelPointsState,
HerdingState,
IterativeKernelHerding,
KernelHerding,
KernelThinning,
MapReduce,
Expand Down Expand Up @@ -2630,21 +2629,3 @@ def test_invalid_coreset_size_incompatible(self):
sqrt_kernel=SquaredExponentialKernel(),
)
solver.reduce(dataset)


class TestIterativeKernelHerding(ExplicitSizeSolverTest):
"""Test cases for :class:`coreax.solvers.coresubset.KernelThinning`."""

@override
@pytest.fixture(scope="class", params=[True, False])
def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
kernel = PCIMQKernel()
coreset_size = self.shape[0] // 10
return jtu.Partial(
IterativeKernelHerding,
coreset_size=coreset_size,
random_key=self.random_key,
kernel=kernel,
probabilistic=request.param,
num_iterations=2,
)

0 comments on commit 538071a

Please sign in to comment.