Feature/benchmarking new solvers #971

Merged 11 commits on Feb 24, 2025
7 changes: 7 additions & 0 deletions .cspell/custom_misc.txt
@@ -16,12 +16,16 @@ delaxes
diag
docstrings
ecolor
edgecolor
eigendecomposition
elementwise
elinewidth
errorbar
figtext
fontsize
fontweight
forall
frameon
GCHQ
Gramian
gramians
@@ -37,10 +41,12 @@ KSD
linestyle
linewidth
mapsto
markersize
Matern
maxs
ml.p3.8xlarge
MNIST
ncol
ndmin
parsable
PCIMQ
@@ -72,3 +78,4 @@ xticks
yerr
ylim
yscale
yticks
50 changes: 47 additions & 3 deletions benchmark/blobs_benchmark.py
@@ -16,7 +16,7 @@
Benchmark performance of different coreset algorithms on a synthetic dataset.

The benchmarking process follows these steps:
1. Generate a synthetic dataset of 1000 two-dimensional points using
1. Generate a synthetic dataset of 1_024 two-dimensional points using
:func:`sklearn.datasets.make_blobs`.
2. Generate coresets of varying sizes: 10, 50, 100, and 200 points using different
coreset algorithms.
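For orientation, step 1 amounts to a call along these lines (a minimal sketch; the number of centres and the random seed are illustrative assumptions, not the benchmark's actual settings):

```python
# Sketch of step 1: a 1_024-point, two-dimensional blob dataset.
import numpy as np
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=1_024, n_features=2, centers=10, random_state=0)
print(np.asarray(x).shape)  # (1024, 2)
```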
@@ -45,6 +45,8 @@
)
from coreax.metrics import KSD, MMD
from coreax.solvers import (
CompressPlusPlus,
IterativeKernelHerding,
KernelHerding,
KernelThinning,
RandomSample,
@@ -153,6 +155,39 @@ def setup_solvers(
sqrt_kernel=sqrt_kernel,
),
),
(
"CompressPlusPlus",
CompressPlusPlus(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
random_key=random_key,
delta=delta,
sqrt_kernel=sqrt_kernel,
g=4,
),
),
(
"ProbabilisticIterativeHerding",
IterativeKernelHerding(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
probabilistic=True,
temperature=0.001,
random_key=random_key,
num_iterations=5,
),
),
(
"IterativeHerding",
IterativeKernelHerding(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
probabilistic=False,
temperature=0.001,
random_key=random_key,
num_iterations=5,
),
),
]
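The newly added solvers use the same reduce interface exercised elsewhere in this PR. A rough usage sketch, with the kernel construction and input data as illustrative assumptions (setup_solvers builds the actual squared-exponential and square-root kernels):

```python
# Sketch only: reduce a toy dataset with one of the newly benchmarked solvers.
import equinox as eqx
import jax.numpy as jnp
from jax import random

from coreax import Data
from coreax.kernels import SquaredExponentialKernel  # assumed kernel choice
from coreax.solvers import IterativeKernelHerding

key = random.PRNGKey(0)
data = Data(jnp.asarray(random.normal(key, (1_024, 2))))

solver = IterativeKernelHerding(
    coreset_size=100,
    kernel=SquaredExponentialKernel(),
    probabilistic=False,
    temperature=0.001,
    random_key=key,
    num_iterations=5,
)
coreset, _ = eqx.filter_jit(solver.reduce)(data)
print(coreset.points.data.shape)  # (100, 2)
```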


@@ -229,8 +264,17 @@ def compute_metrics(


def main() -> None: # pylint: disable=too-many-locals
"""Benchmark various algorithms on a synthetic dataset over multiple seeds."""
n_samples = 1_000
"""
Benchmark various algorithms on a synthetic dataset over multiple seeds.

Generate a synthetic dataset of 1,024 two-dimensional points, as some coreset
algorithms, such as Compress++, perform best when the dataset size is a power
of 4. The function evaluates the performance of different coreset algorithms
on coresets of different sizes (25, 50, 100, and 200 points) over multiple seeds.
The performance of each algorithm is assessed using the Maximum Mean Discrepancy
(MMD) and Kernel Stein Discrepancy (KSD) metrics.
"""
n_samples = 1_024
seeds = [42, 45, 46, 47, 48] # List of seeds to average over
coreset_sizes = [25, 50, 100, 200]
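As a quick check of the power-of-4 remark in the docstring (the same logarithm is used later in this PR, in david_benchmark.py, to derive a Compress++ oversampling factor from the dataset size):

```python
import math

n_samples = 1_024
assert n_samples == 4**5  # exact power of 4, unlike the previous 1_000
print(math.floor(math.log(n_samples, 4)))  # 5
```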

116 changes: 86 additions & 30 deletions benchmark/blobs_benchmark_visualiser.py
@@ -22,10 +22,10 @@

def plot_benchmarking_results(data):
"""
Visualise the benchmarking results.
Visualise the benchmarking results in five separate plots.

:param data: A dictionary where keys are the coreset sizes (as strings) and values
that are dictionaries containing the metrics for each algorithm.
are dictionaries containing the metrics for each algorithm.

Example:
{
@@ -45,46 +45,54 @@ def plot_benchmarking_results(data):
}

"""
title_size = 22
label_size = 18
tick_size = 16
legend_size = 16

first_coreset_size = next(iter(data.keys()))
first_algorithm = next(
iter(data[first_coreset_size].values())
) # Get one example algorithm
first_algorithm = next(iter(data[first_coreset_size].values()))
metrics = list(first_algorithm.keys())
n_metrics = len(metrics)

n_rows = (n_metrics + 1) // 2
fig, axs = plt.subplots(n_rows, 2, figsize=(14, 6 * n_rows))
fig.delaxes(axs[2, 1])
axs = axs.flatten()

# Iterate over each metric and create its subplot
for i, metric in enumerate(metrics):
ax = axs[i]
ax.set_title(
for metric in metrics:
plt.figure(figsize=(10, 8))
plt.title(
f"{metric.replace('_', ' ').title()} vs Coreset Size",
fontsize=14,
fontsize=title_size,
fontweight="bold",
)

# For each algorithm, plot its performance across different subset sizes
for algo in data[list(data.keys())[0]].keys(): # Iterating through algorithms
# Create lists of subset sizes (10, 50, 100, 200)
for algo in data[first_coreset_size].keys():
coreset_sizes = sorted(map(int, data.keys()))
metric_values = [
data[str(subset_size)][algo].get(metric, float("nan"))
for subset_size in coreset_sizes
data[str(size)][algo].get(metric, float("nan"))
for size in coreset_sizes
]

ax.plot(coreset_sizes, metric_values, marker="o", label=algo)
plt.plot(
coreset_sizes,
metric_values,
marker="o",
markersize=8,
linewidth=2.5,
label=algo,
)

plt.xlabel("Coreset Size", fontsize=label_size, fontweight="bold")
plt.ylabel(
f"{metric.replace('_', ' ').title()}",
fontsize=label_size,
fontweight="bold",
)
plt.yscale("log")

ax.set_xlabel("Coreset Size")
ax.set_ylabel(f"{metric.replace('_', ' ').title()}")
ax.set_yscale("log") # log scale for better visualization
ax.legend()
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)

# Adjust layout to avoid overlap
plt.subplots_adjust(hspace=15.0, wspace=1.0)
plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96))
plt.show()
plt.legend(fontsize=legend_size, loc="best", frameon=True)

plt.grid(True, linestyle="--", alpha=0.7)
plt.show()


# Function to print metrics table for each sample size
@@ -134,6 +142,53 @@ def print_metrics_table(data: dict, coreset_size: str) -> None:
print(separator)


def print_rst_metrics_table(data: dict, original_sample_size: int) -> None:
"""
Print metrics tables in reStructuredText format with highlighted best values.

:param data: Dictionary with coreset sizes as keys and nested metrics data
:param original_sample_size: The size of the original sample to display
"""
metrics = [
"Unweighted_MMD",
"Unweighted_KSD",
"Weighted_MMD",
"Weighted_KSD",
"Time",
]

for coreset_size, methods_data in sorted(data.items(), key=lambda x: int(x[0])):
if coreset_size == "n_samples": # Skip the sample size entry
continue

print(
f".. list-table:: Coreset Size {coreset_size} "
f"(Original Sample Size {original_sample_size:,})"
)
print(" :header-rows: 1")
print(" :widths: 20 15 15 15 15 15")
print()
print(" * - Method")
for metric in metrics:
print(f" - {metric}")

# Find best (minimum) values for each metric
best_values = {
metric: min(methods_data[method][metric] for method in methods_data)
for metric in metrics
}

for method in methods_data:
print(f" * - {method}")
for metric in metrics:
value = methods_data[method][metric]
if value == best_values[metric]:
print(f" - **{value:.6f}**") # Highlight best value
else:
print(f" - {value:.6f}")
print()


def main() -> None:
"""Load the data and print metrics in table format per sample size."""
# Load the JSON data
@@ -149,6 +204,7 @@ def main() -> None:
continue
print_metrics_table(data, coreset_size)

print_rst_metrics_table(data, original_sample_size=1024)
plot_benchmarking_results(data)


17 changes: 9 additions & 8 deletions benchmark/david_benchmark.py
@@ -28,6 +28,7 @@
Each coreset algorithm is timed to measure and report the time taken for each step.
"""

import math
import os
import time
from pathlib import Path
@@ -40,7 +41,7 @@
from jax import random

from coreax import Data
from coreax.benchmark_util import get_solver_name, initialise_solvers
from coreax.benchmark_util import initialise_solvers
from examples.david_map_reduce_weighted import downsample_opencv

MAX_8BIT = 255
@@ -77,37 +78,37 @@ def benchmark_coreset_algorithms(
pre_coreset_data = np.column_stack((pre_coreset_data, pixel_values)).astype(
np.float32
)

# Set up the original data object and coreset parameters
data = Data(jnp.asarray(pre_coreset_data))
over_sampling_factor = math.floor(math.log(data.shape[0], 4))
coreset_size = 8_000 // (downsampling_factor**2)

# Initialize each coreset solver
key = random.PRNGKey(0)
solver_factories = initialise_solvers(data, key)
solver_factories = initialise_solvers(
data, key, cpp_oversampling_factor=over_sampling_factor
)

# Dictionary to store coresets generated by each method
coresets = {}
solver_times = {}

for solver_creator in solver_factories:
for solver_name, solver_creator in solver_factories.items():
solver = solver_creator(coreset_size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
coreset, _ = eqx.filter_jit(solver.reduce)(data)
duration = time.perf_counter() - start_time
coresets[solver_name] = coreset.points.data
solver_times[solver_name] = duration

plt.figure(figsize=(15, 10))
plt.subplot(2, 3, 1)
plt.subplot(3, 3, 1)
plt.imshow(original_data, cmap="gray")
plt.title("Original Image")
plt.axis("off")

# Plot each coreset method
for i, (solver_name, coreset_data) in enumerate(coresets.items(), start=2):
plt.subplot(2, 3, i)
plt.subplot(3, 3, i)
plt.scatter(
coreset_data[:, 1],
-coreset_data[:, 0],
9 changes: 5 additions & 4 deletions benchmark/mnist_benchmark.py
@@ -56,7 +56,7 @@
from torchvision.datasets import VisionDataset

from coreax import Data
from coreax.benchmark_util import get_solver_name, initialise_solvers
from coreax.benchmark_util import initialise_solvers
from coreax.util import KeyArrayLike


@@ -530,11 +530,12 @@ def main() -> None:
for i in range(5):
print(f"Run {i + 1} of 5:")
key = jax.random.PRNGKey(i)
solver_factories = initialise_solvers(train_data_umap, key)
for solver_creator in solver_factories:
solver_factories = initialise_solvers(
train_data_umap, key, cpp_oversampling_factor=7, leaf_size=15_000
)
for solver_name, solver_creator in solver_factories.items():
for size in [25, 50, 100, 500, 1_000, 5_000]:
solver = solver_creator(size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
# pylint: enable=duplicate-code
coreset, _ = eqx.filter_jit(solver.reduce)(train_data_umap)
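For reference, the hard-coded cpp_oversampling_factor=7 matches the floor(log4(n)) rule used in david_benchmark.py if the UMAP-embedded training set keeps all 60,000 MNIST training points (an assumption; the preprocessing is outside this hunk):

```python
import math

n_train = 60_000  # assumed size of the UMAP-embedded MNIST training split
print(math.floor(math.log(n_train, 4)))  # 7
```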
15 changes: 8 additions & 7 deletions benchmark/mnist_benchmark_coresets_only.py
@@ -33,14 +33,13 @@

import equinox as eqx
import jax

from benchmark.mnist_benchmark import (
from mnist_benchmark import (
density_preserving_umap,
get_solver_name,
initialise_solvers,
prepare_datasets,
)

from coreax import Data
from coreax.benchmark_util import initialise_solvers


def save_results(results: dict) -> None:
@@ -103,12 +102,14 @@ def main() -> None:
for i in range(5):
print(f"Run {i + 1} of 5:")
key = jax.random.PRNGKey(i)
solver_factories = initialise_solvers(train_data_umap, key)
for solver_creator in solver_factories:
solver_factories = initialise_solvers(
train_data_umap, key, cpp_oversampling_factor=7, leaf_size=15_000
)
for solver_name, solver_creator in solver_factories.items():
for size in [25, 50, 100, 500, 1_000]:
solver = solver_creator(size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
# pylint: enable=duplicate-code
_, _ = eqx.filter_jit(solver.reduce)(train_data_umap)
time_taken = time.perf_counter() - start_time
