diff --git a/.cspell/custom_misc.txt b/.cspell/custom_misc.txt index acb5405f..6bc122cc 100644 --- a/.cspell/custom_misc.txt +++ b/.cspell/custom_misc.txt @@ -16,12 +16,16 @@ delaxes diag docstrings ecolor +edgecolor eigendecomposition elementwise +elinewidth errorbar figtext fontsize +fontweight forall +frameon GCHQ Gramian gramians @@ -37,10 +41,12 @@ KSD linestyle linewidth mapsto +markersize Matern maxs ml.p3.8xlarge MNIST +ncol ndmin parsable PCIMQ @@ -72,3 +78,4 @@ xticks yerr ylim yscale +yticks diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py index 2b7fc22f..82d7afb7 100644 --- a/benchmark/blobs_benchmark.py +++ b/benchmark/blobs_benchmark.py @@ -16,7 +16,7 @@ Benchmark performance of different coreset algorithms on a synthetic dataset. The benchmarking process follows these steps: -1. Generate a synthetic dataset of 1000 two-dimensional points using +1. Generate a synthetic dataset of 1_024 two-dimensional points using :func:`sklearn.datasets.make_blobs`. 2. Generate coresets of varying sizes: 10, 50, 100, and 200 points using different coreset algorithms. @@ -45,6 +45,8 @@ ) from coreax.metrics import KSD, MMD from coreax.solvers import ( + CompressPlusPlus, + IterativeKernelHerding, KernelHerding, KernelThinning, RandomSample, @@ -153,6 +155,39 @@ def setup_solvers( sqrt_kernel=sqrt_kernel, ), ), + ( + "CompressPlusPlus", + CompressPlusPlus( + coreset_size=coreset_size, + kernel=sq_exp_kernel, + random_key=random_key, + delta=delta, + sqrt_kernel=sqrt_kernel, + g=4, + ), + ), + ( + "ProbabilisticIterativeHerding", + IterativeKernelHerding( + coreset_size=coreset_size, + kernel=sq_exp_kernel, + probabilistic=True, + temperature=0.001, + random_key=random_key, + num_iterations=5, + ), + ), + ( + "IterativeHerding", + IterativeKernelHerding( + coreset_size=coreset_size, + kernel=sq_exp_kernel, + probabilistic=False, + temperature=0.001, + random_key=random_key, + num_iterations=5, + ), + ), ] @@ -229,8 +264,17 @@ def compute_metrics( def main() -> None: # pylint: disable=too-many-locals - """Benchmark various algorithms on a synthetic dataset over multiple seeds.""" - n_samples = 1_000 + """ + Benchmark various algorithms on a synthetic dataset over multiple seeds. + + Generate a synthetic dataset of 1,024 two-dimensional points, since some coreset + algorithms, such as Compress++, perform best when the dataset size is a power + of 4. The function evaluates the performance of different coreset algorithms + on coresets of different sizes (25, 50, 100, and 200 points) over multiple seeds. + The performance of each algorithm is assessed using the Maximum Mean Discrepancy + (MMD) and Kernel Stein Discrepancy (KSD) metrics. + """ + n_samples = 1_024 seeds = [42, 45, 46, 47, 48] # List of seeds to average over coreset_sizes = [25, 50, 100, 200] diff --git a/benchmark/blobs_benchmark_visualiser.py b/benchmark/blobs_benchmark_visualiser.py index 92644f3c..04ff101b 100644 --- a/benchmark/blobs_benchmark_visualiser.py +++ b/benchmark/blobs_benchmark_visualiser.py @@ -22,10 +22,10 @@ def plot_benchmarking_results(data): """ - Visualise the benchmarking results. + Visualise the benchmarking results in five separate plots. :param data: A dictionary where keys are the coreset sizes (as strings) and values - that are dictionaries containing the metrics for each algorithm. + are dictionaries containing the metrics for each algorithm.
Example: { @@ -45,46 +45,54 @@ def plot_benchmarking_results(data): } """ + title_size = 22 + label_size = 18 + tick_size = 16 + legend_size = 16 + first_coreset_size = next(iter(data.keys())) - first_algorithm = next( - iter(data[first_coreset_size].values()) - ) # Get one example algorithm + first_algorithm = next(iter(data[first_coreset_size].values())) metrics = list(first_algorithm.keys()) - n_metrics = len(metrics) - n_rows = (n_metrics + 1) // 2 - fig, axs = plt.subplots(n_rows, 2, figsize=(14, 6 * n_rows)) - fig.delaxes(axs[2, 1]) - axs = axs.flatten() - - # Iterate over each metric and create its subplot - for i, metric in enumerate(metrics): - ax = axs[i] - ax.set_title( + for metric in metrics: + plt.figure(figsize=(10, 8)) + plt.title( f"{metric.replace('_', ' ').title()} vs Coreset Size", - fontsize=14, + fontsize=title_size, + fontweight="bold", ) - # For each algorithm, plot its performance across different subset sizes - for algo in data[list(data.keys())[0]].keys(): # Iterating through algorithms - # Create lists of subset sizes (10, 50, 100, 200) + for algo in data[first_coreset_size].keys(): coreset_sizes = sorted(map(int, data.keys())) metric_values = [ - data[str(subset_size)][algo].get(metric, float("nan")) - for subset_size in coreset_sizes + data[str(size)][algo].get(metric, float("nan")) + for size in coreset_sizes ] - ax.plot(coreset_sizes, metric_values, marker="o", label=algo) + plt.plot( + coreset_sizes, + metric_values, + marker="o", + markersize=8, + linewidth=2.5, + label=algo, + ) + + plt.xlabel("Coreset Size", fontsize=label_size, fontweight="bold") + plt.ylabel( + f"{metric.replace('_', ' ').title()}", + fontsize=label_size, + fontweight="bold", + ) + plt.yscale("log") - ax.set_xlabel("Coreset Size") - ax.set_ylabel(f"{metric.replace('_', ' ').title()}") - ax.set_yscale("log") # log scale for better visualization - ax.legend() + plt.xticks(fontsize=tick_size) + plt.yticks(fontsize=tick_size) - # Adjust layout to avoid overlap - plt.subplots_adjust(hspace=15.0, wspace=1.0) - plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96)) - plt.show() + plt.legend(fontsize=legend_size, loc="best", frameon=True) + + plt.grid(True, linestyle="--", alpha=0.7) + plt.show() # Function to print metrics table for each sample size @@ -134,6 +142,53 @@ def print_metrics_table(data: dict, coreset_size: str) -> None: print(separator) +def print_rst_metrics_table(data: dict, original_sample_size: int) -> None: + """ + Print metrics tables in reStructuredText format with highlighted best values. + + :param data: Dictionary with coreset sizes as keys and nested metrics data + :param original_sample_size: The size of the original sample to display + """ + metrics = [ + "Unweighted_MMD", + "Unweighted_KSD", + "Weighted_MMD", + "Weighted_KSD", + "Time", + ] + + for coreset_size, methods_data in sorted(data.items(), key=lambda x: int(x[0])): + if coreset_size == "n_samples": # Skip the sample size entry + continue + + print( + f".. 
list-table:: Coreset Size {coreset_size} " + f"(Original Sample Size {original_sample_size:,})" + ) + print(" :header-rows: 1") + print(" :widths: 20 15 15 15 15 15") + print() + print(" * - Method") + for metric in metrics: + print(f" - {metric}") + + # Find best (minimum) values for each metric + best_values = { + metric: min(methods_data[method][metric] for method in methods_data) + for metric in metrics + } + + for method in methods_data: + print(f" * - {method}") + for metric in metrics: + value = methods_data[method][metric] + if value == best_values[metric]: + print(f" - **{value:.6f}**") # Highlight best value + else: + print(f" - {value:.6f}") + print() + + def main() -> None: """Load the data and print metrics in table format per sample size.""" # Load the JSON data @@ -149,6 +204,7 @@ def main() -> None: continue print_metrics_table(data, coreset_size) + print_rst_metrics_table(data, original_sample_size=1024) plot_benchmarking_results(data) diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py index a4995838..7a3ccbef 100644 --- a/benchmark/david_benchmark.py +++ b/benchmark/david_benchmark.py @@ -28,6 +28,7 @@ Each coreset algorithm is timed to measure and report the time taken for each step. """ +import math import os import time from pathlib import Path @@ -40,7 +41,7 @@ from jax import random from coreax import Data -from coreax.benchmark_util import get_solver_name, initialise_solvers +from coreax.benchmark_util import initialise_solvers from examples.david_map_reduce_weighted import downsample_opencv MAX_8BIT = 255 @@ -77,22 +78,22 @@ def benchmark_coreset_algorithms( pre_coreset_data = np.column_stack((pre_coreset_data, pixel_values)).astype( np.float32 ) - # Set up the original data object and coreset parameters data = Data(jnp.asarray(pre_coreset_data)) + over_sampling_factor = math.floor(math.log(data.shape[0], 4)) coreset_size = 8_000 // (downsampling_factor**2) - # Initialize each coreset solver key = random.PRNGKey(0) - solver_factories = initialise_solvers(data, key) + solver_factories = initialise_solvers( + data, key, cpp_oversampling_factor=over_sampling_factor + ) # Dictionary to store coresets generated by each method coresets = {} solver_times = {} - for solver_creator in solver_factories: + for solver_name, solver_creator in solver_factories.items(): solver = solver_creator(coreset_size) - solver_name = get_solver_name(solver_creator) start_time = time.perf_counter() coreset, _ = eqx.filter_jit(solver.reduce)(data) duration = time.perf_counter() - start_time @@ -100,14 +101,14 @@ def benchmark_coreset_algorithms( solver_times[solver_name] = duration plt.figure(figsize=(15, 10)) - plt.subplot(2, 3, 1) + plt.subplot(3, 3, 1) plt.imshow(original_data, cmap="gray") plt.title("Original Image") plt.axis("off") # Plot each coreset method for i, (solver_name, coreset_data) in enumerate(coresets.items(), start=2): - plt.subplot(2, 3, i) + plt.subplot(3, 3, i) plt.scatter( coreset_data[:, 1], -coreset_data[:, 0], diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 31853e5f..083610e3 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -56,7 +56,7 @@ from torchvision.datasets import VisionDataset from coreax import Data -from coreax.benchmark_util import get_solver_name, initialise_solvers +from coreax.benchmark_util import initialise_solvers from coreax.util import KeyArrayLike @@ -530,11 +530,12 @@ def main() -> None: for i in range(5): print(f"Run {i + 1} of 5:") key = jax.random.PRNGKey(i) - 
solver_factories = initialise_solvers(train_data_umap, key) - for solver_creator in solver_factories: + solver_factories = initialise_solvers( + train_data_umap, key, cpp_oversampling_factor=7, leaf_size=15_000 + ) + for solver_name, solver_creator in solver_factories.items(): for size in [25, 50, 100, 500, 1_000, 5_000]: solver = solver_creator(size) - solver_name = get_solver_name(solver_creator) start_time = time.perf_counter() # pylint: enable=duplicate-code coreset, _ = eqx.filter_jit(solver.reduce)(train_data_umap) diff --git a/benchmark/mnist_benchmark_coresets_only.py b/benchmark/mnist_benchmark_coresets_only.py index 8191ff7c..0ddfe115 100644 --- a/benchmark/mnist_benchmark_coresets_only.py +++ b/benchmark/mnist_benchmark_coresets_only.py @@ -33,14 +33,13 @@ import equinox as eqx import jax - -from benchmark.mnist_benchmark import ( +from mnist_benchmark import ( density_preserving_umap, - get_solver_name, - initialise_solvers, prepare_datasets, ) + from coreax import Data +from coreax.benchmark_util import initialise_solvers def save_results(results: dict) -> None: @@ -103,12 +102,14 @@ def main() -> None: for i in range(5): print(f"Run {i + 1} of 5:") key = jax.random.PRNGKey(i) - solver_factories = initialise_solvers(train_data_umap, key) - for solver_creator in solver_factories: + solver_factories = initialise_solvers( + train_data_umap, key, cpp_oversampling_factor=7, leaf_size=15_000 + ) + for solver_name, solver_creator in solver_factories.items(): for size in [25, 50, 100, 500, 1_000]: solver = solver_creator(size) - solver_name = get_solver_name(solver_creator) start_time = time.perf_counter() + # pylint: enable=duplicate-code _, _ = eqx.filter_jit(solver.reduce)(train_data_umap) time_taken = time.perf_counter() - start_time diff --git a/benchmark/mnist_benchmark_results.json b/benchmark/mnist_benchmark_results.json index feadc718..4f357074 100644 --- a/benchmark/mnist_benchmark_results.json +++ b/benchmark/mnist_benchmark_results.json @@ -1,671 +1,939 @@ { - "RandomSample": { + "Random Sample": { "25": { "0": { "accuracy": 0.47499004006385803, - "time_taken": 23.26633542699983 + "time_taken": 23.352234674006468 }, "1": { "accuracy": 0.5094035863876343, - "time_taken": 24.154528033000133 + "time_taken": 24.481476221000776 }, "2": { - "accuracy": 0.4431772530078888, - "time_taken": 26.308422624999366 + "accuracy": 0.4429771900177002, + "time_taken": 25.436731434005196 }, "3": { - "accuracy": 0.5268108248710632, - "time_taken": 22.50174796200008 + "accuracy": 0.5266107320785522, + "time_taken": 21.81958907599619 }, "4": { "accuracy": 0.4708881676197052, - "time_taken": 25.894998888999908 + "time_taken": 25.530452502993285 } }, "50": { "0": { "accuracy": 0.6092996001243591, - "time_taken": 7.9610507360002885 + "time_taken": 8.290402896993328 }, "1": { "accuracy": 0.6098998785018921, - "time_taken": 10.142288449999796 + "time_taken": 10.664244551997399 }, "2": { "accuracy": 0.5316997766494751, - "time_taken": 7.382097613000042 + "time_taken": 7.351946285998565 }, "3": { "accuracy": 0.6038997769355774, - "time_taken": 7.2605028240004685 + "time_taken": 7.368035252991831 }, "4": { "accuracy": 0.6276999711990356, - "time_taken": 15.606078997999248 + "time_taken": 15.862566682990291 } }, "100": { "0": { "accuracy": 0.704299807548523, - "time_taken": 6.080028390000734 + "time_taken": 6.759049768996192 }, "1": { "accuracy": 0.7321001887321472, - "time_taken": 5.851846229999865 + "time_taken": 5.9133299559907755 }, "2": { "accuracy": 0.7243001461029053, - "time_taken": 
9.87153262100037 + "time_taken": 9.780433213003562 }, "3": { "accuracy": 0.7279003262519836, - "time_taken": 3.8054748680006014 + "time_taken": 3.8765327779983636 }, "4": { "accuracy": 0.6853998899459839, - "time_taken": 7.10781191300066 + "time_taken": 7.18066198201268 } }, "500": { "0": { "accuracy": 0.849459171295166, - "time_taken": 4.452629182000237 + "time_taken": 4.7403871529968455 }, "1": { "accuracy": 0.8390424847602844, - "time_taken": 2.864639194999654 + "time_taken": 2.9221081489959033 }, "2": { "accuracy": 0.8586738705635071, - "time_taken": 4.056492333999813 + "time_taken": 4.059734050009865 }, "3": { "accuracy": 0.8433493971824646, - "time_taken": 3.1211573710006633 + "time_taken": 3.1853989910014207 }, "4": { "accuracy": 0.8508613705635071, - "time_taken": 4.570619415000692 + "time_taken": 4.598661220996291 } }, "1000": { "0": { "accuracy": 0.8806089758872986, - "time_taken": 4.099615649000043 + "time_taken": 4.128512479001074 }, "1": { "accuracy": 0.8755007982254028, - "time_taken": 2.84869260799951 + "time_taken": 2.8908094719954534 }, "2": { "accuracy": 0.8828125, - "time_taken": 2.8027847920002387 + "time_taken": 2.8451506499986863 }, "3": { "accuracy": 0.8731971383094788, - "time_taken": 2.9280455670004812 + "time_taken": 2.985724409998511 }, "4": { "accuracy": 0.8818109035491943, - "time_taken": 3.1732710930000394 + "time_taken": 3.223807489994215 } }, "5000": { "0": { "accuracy": 0.9258814454078674, - "time_taken": 4.808443182000701 + "time_taken": 4.811194872003398 }, "1": { "accuracy": 0.9238781929016113, - "time_taken": 3.5261693180000293 + "time_taken": 3.5581146879994776 }, "2": { "accuracy": 0.9277844429016113, - "time_taken": 3.4832477119998657 + "time_taken": 3.5328966109955218 }, "3": { "accuracy": 0.9291867017745972, - "time_taken": 5.071235248999983 + "time_taken": 5.172802444998524 }, "4": { "accuracy": 0.9294871687889099, - "time_taken": 3.7022924619996047 + "time_taken": 3.727755020998302 } } }, - "RPCholesky": { + "RP Cholesky": { "25": { "0": { - "accuracy": 0.4749898314476013, - "time_taken": 17.966321985999457 + "accuracy": 0.48649486899375916, + "time_taken": 15.06727149800281 }, "1": { - "accuracy": 0.5499197840690613, - "time_taken": 16.858777587000077 + "accuracy": 0.5369148850440979, + "time_taken": 16.84173994199955 }, "2": { - "accuracy": 0.5490198731422424, - "time_taken": 21.207069888999285 + "accuracy": 0.5941379070281982, + "time_taken": 28.234114010003395 }, "3": { - "accuracy": 0.4454781413078308, - "time_taken": 19.69648443999995 + "accuracy": 0.5332134366035461, + "time_taken": 22.014903976989444 }, "4": { - "accuracy": 0.5133053064346313, - "time_taken": 24.770281666999836 + "accuracy": 0.5119049549102783, + "time_taken": 24.88884191699617 } }, "50": { "0": { - "accuracy": 0.5790995359420776, - "time_taken": 7.9108207769995715 + "accuracy": 0.5625001192092896, + "time_taken": 15.471359351999126 }, "1": { - "accuracy": 0.6649996638298035, - "time_taken": 8.449880023000333 + "accuracy": 0.6616997122764587, + "time_taken": 9.69486076499743 }, "2": { - "accuracy": 0.6247993111610413, - "time_taken": 9.283834525000202 + "accuracy": 0.5726996660232544, + "time_taken": 7.714200211994466 }, "3": { - "accuracy": 0.6294994950294495, - "time_taken": 14.32392578100007 + "accuracy": 0.633499801158905, + "time_taken": 10.867528442002367 }, "4": { - "accuracy": 0.5775997638702393, - "time_taken": 13.211208218999673 + "accuracy": 0.6135995984077454, + "time_taken": 17.245890217003762 } }, "100": { "0": { - "accuracy": 0.7272999882698059, - 
"time_taken": 9.433186399000078 + "accuracy": 0.6975999474525452, + "time_taken": 13.063453700000537 }, "1": { - "accuracy": 0.6856998801231384, - "time_taken": 7.170746434999273 + "accuracy": 0.7179999947547913, + "time_taken": 7.070716499001719 }, "2": { - "accuracy": 0.6285001635551453, - "time_taken": 6.053287822000129 + "accuracy": 0.6369000673294067, + "time_taken": 5.612793920998229 }, "3": { - "accuracy": 0.6865997910499573, - "time_taken": 5.618571209000038 + "accuracy": 0.6892000436782837, + "time_taken": 6.446513513001264 }, "4": { - "accuracy": 0.6816001534461975, - "time_taken": 8.63684380199993 + "accuracy": 0.6939999461174011, + "time_taken": 8.73156723100692 } }, "500": { "0": { - "accuracy": 0.8495593070983887, - "time_taken": 5.951769608000177 + "accuracy": 0.8505609035491943, + "time_taken": 6.197642628001631 }, "1": { - "accuracy": 0.8297275900840759, - "time_taken": 3.830718762000288 + "accuracy": 0.8322315812110901, + "time_taken": 3.782923115999438 }, "2": { - "accuracy": 0.8172075152397156, - "time_taken": 4.784693309000431 + "accuracy": 0.8162059187889099, + "time_taken": 4.230790430010529 }, "3": { - "accuracy": 0.805588960647583, - "time_taken": 4.134177151000586 + "accuracy": 0.8254206776618958, + "time_taken": 4.489102895997348 }, "4": { - "accuracy": 0.8328325152397156, - "time_taken": 4.010456953999892 + "accuracy": 0.8451522588729858, + "time_taken": 3.9125116750074085 } }, "1000": { "0": { - "accuracy": 0.8742988705635071, - "time_taken": 5.9872287300004245 + "accuracy": 0.8621795177459717, + "time_taken": 6.000493390994961 }, "1": { - "accuracy": 0.8658854365348816, - "time_taken": 4.616141986999537 + "accuracy": 0.8517628312110901, + "time_taken": 4.43663791801373 }, "2": { - "accuracy": 0.8723958730697632, - "time_taken": 4.796533407000425 + "accuracy": 0.8653846383094788, + "time_taken": 4.807365334010683 }, "3": { - "accuracy": 0.8480569124221802, - "time_taken": 4.392059116000382 + "accuracy": 0.8575721383094788, + "time_taken": 4.603306793986121 }, "4": { - "accuracy": 0.8732972741127014, - "time_taken": 4.212848541999847 + "accuracy": 0.8751001954078674, + "time_taken": 4.6778737399872625 } }, "5000": { "0": { - "accuracy": 0.9250801205635071, - "time_taken": 29.899051755000073 + "accuracy": 0.9258814454078674, + "time_taken": 30.59640652500093 + }, + "1": { + "accuracy": 0.9270833730697632, + "time_taken": 28.73688672199205 + }, + "2": { + "accuracy": 0.9242788553237915, + "time_taken": 27.798969063995173 + }, + "3": { + "accuracy": 0.9295873641967773, + "time_taken": 28.526359000999946 + }, + "4": { + "accuracy": 0.9217748641967773, + "time_taken": 28.89112439600285 + } + } + }, + "Kernel Herding": { + "25": { + "0": { + "accuracy": 0.41976743936538696, + "time_taken": 16.602575227007037 + }, + "1": { + "accuracy": 0.3865547478199005, + "time_taken": 39.596862955004326 + }, + "2": { + "accuracy": 0.4191674292087555, + "time_taken": 27.549024899999495 + }, + "3": { + "accuracy": 0.3826534152030945, + "time_taken": 15.041240438004024 + }, + "4": { + "accuracy": 0.36784714460372925, + "time_taken": 31.221653377986513 + } + }, + "50": { + "0": { + "accuracy": 0.4702000617980957, + "time_taken": 9.662910286991973 + }, + "1": { + "accuracy": 0.49190017580986023, + "time_taken": 8.503036755995709 + }, + "2": { + "accuracy": 0.4674000144004822, + "time_taken": 6.310842988008517 + }, + "3": { + "accuracy": 0.46950021386146545, + "time_taken": 5.354004640001222 + }, + "4": { + "accuracy": 0.5017001032829285, + "time_taken": 8.116046239010757 + } + }, + 
"100": { + "0": { + "accuracy": 0.6386004090309143, + "time_taken": 7.6281732080096845 + }, + "1": { + "accuracy": 0.6136000752449036, + "time_taken": 5.189364300007583 + }, + "2": { + "accuracy": 0.6494002342224121, + "time_taken": 5.8485669609945035 + }, + "3": { + "accuracy": 0.639400064945221, + "time_taken": 5.13314476099913 + }, + "4": { + "accuracy": 0.626500129699707, + "time_taken": 5.724788473991794 + } + }, + "500": { + "0": { + "accuracy": 0.796875, + "time_taken": 5.436996539996471 + }, + "1": { + "accuracy": 0.7882612347602844, + "time_taken": 3.7532910439913394 + }, + "2": { + "accuracy": 0.7895632982254028, + "time_taken": 3.608708291008952 + }, + "3": { + "accuracy": 0.8052884936332703, + "time_taken": 3.843011178993038 + }, + "4": { + "accuracy": 0.7890625, + "time_taken": 4.171767199004535 + } + }, + "1000": { + "0": { + "accuracy": 0.8404447436332703, + "time_taken": 4.949075399010326 + }, + "1": { + "accuracy": 0.8525640964508057, + "time_taken": 3.195582884000032 + }, + "2": { + "accuracy": 0.848557710647583, + "time_taken": 3.2079661229945486 + }, + "3": { + "accuracy": 0.8504607677459717, + "time_taken": 3.3226542650081683 + }, + "4": { + "accuracy": 0.8540664911270142, + "time_taken": 3.8391184149950277 + } + }, + "5000": { + "0": { + "accuracy": 0.9304887652397156, + "time_taken": 7.827742292007315 }, "1": { "accuracy": 0.9318910241127014, - "time_taken": 28.640717678999863 + "time_taken": 5.016722931992263 }, "2": { - "accuracy": 0.9252804517745972, - "time_taken": 27.780483479000395 + "accuracy": 0.9341947436332703, + "time_taken": 5.358994267997332 }, "3": { - "accuracy": 0.9284855723381042, - "time_taken": 28.07316877300036 + "accuracy": 0.9357972741127014, + "time_taken": 6.045616682997206 }, "4": { - "accuracy": 0.9221754670143127, - "time_taken": 28.88361871699999 + "accuracy": 0.9310897588729858, + "time_taken": 5.193479452995234 } } }, - "KernelHerding": { + "Stein Thinning": { "25": { "0": { - "accuracy": 0.4633856415748596, - "time_taken": 21.417966922000232 + "accuracy": 0.3557422161102295, + "time_taken": 14.798573687992757 }, "1": { - "accuracy": 0.41996797919273376, - "time_taken": 21.351254336999773 + "accuracy": 0.37554997205734253, + "time_taken": 19.849432048999006 }, "2": { - "accuracy": 0.43827515840530396, - "time_taken": 18.99325110800055 + "accuracy": 0.365746408700943, + "time_taken": 15.417942133004544 }, "3": { - "accuracy": 0.43787533044815063, - "time_taken": 15.95066889699956 + "accuracy": 0.34683868288993835, + "time_taken": 14.829650210987893 }, "4": { - "accuracy": 0.4330735206604004, - "time_taken": 19.170973098000104 + "accuracy": 0.3714485466480255, + "time_taken": 16.72217624800396 } }, "50": { "0": { - "accuracy": 0.5270997285842896, - "time_taken": 8.63385688299968 + "accuracy": 0.44350001215934753, + "time_taken": 11.345535130007192 }, "1": { - "accuracy": 0.523399829864502, - "time_taken": 5.941639144999499 + "accuracy": 0.4461999833583832, + "time_taken": 13.185699121997459 }, "2": { - "accuracy": 0.5230996608734131, - "time_taken": 8.082218694000403 + "accuracy": 0.45210036635398865, + "time_taken": 10.52003458699619 }, "3": { - "accuracy": 0.5010999441146851, - "time_taken": 7.190972131000308 + "accuracy": 0.4285001754760742, + "time_taken": 11.961328809004044 }, "4": { - "accuracy": 0.5044997930526733, - "time_taken": 7.761757675999434 + "accuracy": 0.42170006036758423, + "time_taken": 10.88203769500251 } }, "100": { "0": { - "accuracy": 0.6518000364303589, - "time_taken": 6.5709225259997766 + "accuracy": 
0.5332000851631165, + "time_taken": 9.059783407006762 }, "1": { - "accuracy": 0.5968999862670898, - "time_taken": 4.8999388629999885 + "accuracy": 0.4996998608112335, + "time_taken": 8.194069575009053 }, "2": { - "accuracy": 0.5978002548217773, - "time_taken": 5.410862422999344 + "accuracy": 0.5352000594139099, + "time_taken": 8.481501229995047 }, "3": { - "accuracy": 0.6083003878593445, - "time_taken": 4.623381014000188 + "accuracy": 0.4972999095916748, + "time_taken": 8.426191655002185 }, "4": { - "accuracy": 0.6084000468254089, - "time_taken": 5.331589719000476 + "accuracy": 0.4962000250816345, + "time_taken": 8.92351122800028 } }, "500": { "0": { - "accuracy": 0.7951722741127014, - "time_taken": 5.391921381999964 + "accuracy": 0.6190905570983887, + "time_taken": 9.206097094007418 }, "1": { - "accuracy": 0.8112980723381042, - "time_taken": 3.77097700899958 + "accuracy": 0.6015625, + "time_taken": 9.124566219994449 }, "2": { - "accuracy": 0.8093950152397156, - "time_taken": 3.864875149999534 + "accuracy": 0.6050680875778198, + "time_taken": 8.69083937999676 }, "3": { - "accuracy": 0.8002804517745972, - "time_taken": 3.4463949179998963 + "accuracy": 0.5941506624221802, + "time_taken": 9.4010932680103 }, "4": { - "accuracy": 0.7914663553237915, - "time_taken": 3.883257230999334 + "accuracy": 0.5866386294364929, + "time_taken": 9.061567828000989 } }, "1000": { "0": { - "accuracy": 0.846754789352417, - "time_taken": 5.342803989999993 + "accuracy": 0.611177921295166, + "time_taken": 9.522722403999069 }, "1": { - "accuracy": 0.859375, - "time_taken": 3.8688295120000475 + "accuracy": 0.6095753312110901, + "time_taken": 9.849882042995887 }, "2": { - "accuracy": 0.8543670177459717, - "time_taken": 3.8199609699995563 + "accuracy": 0.6153846383094788, + "time_taken": 10.069659473010688 }, "3": { - "accuracy": 0.8598757982254028, - "time_taken": 3.771100228999785 + "accuracy": 0.6005609035491943, + "time_taken": 10.171524318007869 }, "4": { - "accuracy": 0.8578726053237915, - "time_taken": 3.5506639280001764 + "accuracy": 0.6041666865348816, + "time_taken": 10.320493794002687 } }, "5000": { "0": { - "accuracy": 0.9211738705635071, - "time_taken": 8.47471116900033 + "accuracy": 0.6329126954078674, + "time_taken": 29.329851916991174 }, "1": { - "accuracy": 0.922776460647583, - "time_taken": 4.937656685000547 + "accuracy": 0.6323117017745972, + "time_taken": 29.234513267001603 }, "2": { - "accuracy": 0.9282852411270142, - "time_taken": 6.115159175999906 + "accuracy": 0.6167868375778198, + "time_taken": 33.62631028499163 }, "3": { - "accuracy": 0.9300881624221802, - "time_taken": 5.731699566000316 + "accuracy": 0.6232972741127014, + "time_taken": 29.677587008001865 }, "4": { - "accuracy": 0.9243789911270142, - "time_taken": 4.7779110970004695 + "accuracy": 0.6148838400840759, + "time_taken": 29.441241257009096 } } }, - "SteinThinning": { + "Kernel Thinning": { "25": { "0": { - "accuracy": 0.402760773897171, - "time_taken": 19.745292868000433 + "accuracy": 0.5086035132408142, + "time_taken": 86.50353899000038 }, "1": { - "accuracy": 0.36034372448921204, - "time_taken": 19.371978351000507 + "accuracy": 0.5096036195755005, + "time_taken": 27.584838315000525 }, "2": { - "accuracy": 0.3401360809803009, - "time_taken": 14.285965998000393 + "accuracy": 0.4775909185409546, + "time_taken": 25.87109093600884 }, "3": { - "accuracy": 0.3896558880805969, - "time_taken": 16.36114198499945 + "accuracy": 0.4638856053352356, + "time_taken": 18.325479343999177 }, "4": { - "accuracy": 0.3565424978733063, - "time_taken": 
20.385888099999647 + "accuracy": 0.37675046920776367, + "time_taken": 14.236614166002255 } }, "50": { "0": { - "accuracy": 0.4324001669883728, - "time_taken": 13.217264007999802 + "accuracy": 0.630199670791626, + "time_taken": 46.260885070994846 }, "1": { - "accuracy": 0.42600005865097046, - "time_taken": 9.966807724000319 + "accuracy": 0.6344994902610779, + "time_taken": 10.804280167998513 }, "2": { - "accuracy": 0.4104001522064209, - "time_taken": 10.416085636999924 + "accuracy": 0.5950995683670044, + "time_taken": 12.697421856006258 }, "3": { - "accuracy": 0.41020023822784424, - "time_taken": 10.325945335999677 + "accuracy": 0.5381998419761658, + "time_taken": 13.086974981997628 }, "4": { - "accuracy": 0.3909001350402832, - "time_taken": 13.138384072000008 + "accuracy": 0.5856001377105713, + "time_taken": 7.364729672000976 } }, "100": { "0": { - "accuracy": 0.4612000286579132, - "time_taken": 9.312922237000748 + "accuracy": 0.7089999914169312, + "time_taken": 27.71084121799504 }, "1": { - "accuracy": 0.469699889421463, - "time_taken": 8.583355914000094 + "accuracy": 0.7250999212265015, + "time_taken": 6.485319369006902 }, "2": { - "accuracy": 0.4544999897480011, - "time_taken": 8.941393133999554 + "accuracy": 0.735200047492981, + "time_taken": 6.369364660000429 }, "3": { - "accuracy": 0.4674999415874481, - "time_taken": 8.999052006000056 + "accuracy": 0.6910000443458557, + "time_taken": 4.998287808994064 }, "4": { - "accuracy": 0.451499879360199, - "time_taken": 9.512783743 + "accuracy": 0.7174001336097717, + "time_taken": 5.89240234499448 } }, "500": { "0": { - "accuracy": 0.5759214758872986, - "time_taken": 8.666155873000207 + "accuracy": 0.8403445482254028, + "time_taken": 10.419982768013142 }, "1": { - "accuracy": 0.5520833134651184, - "time_taken": 8.722324781999305 + "accuracy": 0.8527644276618958, + "time_taken": 3.7628963680035667 }, "2": { - "accuracy": 0.5831330418586731, - "time_taken": 9.471579476999977 + "accuracy": 0.858473539352417, + "time_taken": 3.6505097879999084 }, "3": { - "accuracy": 0.5356570482254028, - "time_taken": 9.300596547000168 + "accuracy": 0.8390424847602844, + "time_taken": 3.832862326002214 }, "4": { - "accuracy": 0.5519831776618958, - "time_taken": 9.42204293100076 + "accuracy": 0.8510617017745972, + "time_taken": 3.9164835209958255 } }, "1000": { "0": { - "accuracy": 0.5844351053237915, - "time_taken": 10.180833551000433 + "accuracy": 0.8800080418586731, + "time_taken": 10.038269149008556 }, "1": { - "accuracy": 0.5734174847602844, - "time_taken": 9.615570485999342 + "accuracy": 0.8777043223381042, + "time_taken": 3.7039374739979394 }, "2": { - "accuracy": 0.5765224695205688, - "time_taken": 9.886270799000158 + "accuracy": 0.8801081776618958, + "time_taken": 3.2720982680039015 }, "3": { - "accuracy": 0.5474759936332703, - "time_taken": 9.45628960400063 + "accuracy": 0.879807710647583, + "time_taken": 3.7813302020076662 }, "4": { - "accuracy": 0.5767227411270142, - "time_taken": 13.535986196000522 + "accuracy": 0.8847155570983887, + "time_taken": 3.6361338120041182 } }, "5000": { "0": { - "accuracy": 0.6347155570983887, - "time_taken": 29.896120647000316 + "accuracy": 0.9284855723381042, + "time_taken": 16.11581551800191 }, "1": { - "accuracy": 0.6311097741127014, - "time_taken": 29.06243791899942 + "accuracy": 0.9262820482254028, + "time_taken": 5.4665557440021075 }, "2": { - "accuracy": 0.650240421295166, - "time_taken": 33.30571946100008 + "accuracy": 0.9265825152397156, + "time_taken": 6.757973326995852 }, "3": { - "accuracy": 0.6276041865348816, 
- "time_taken": 30.51324089100035 + "accuracy": 0.9316906929016113, + "time_taken": 6.069867120997515 }, "4": { - "accuracy": 0.626802921295166, - "time_taken": 29.449209794000126 + "accuracy": 0.9258814454078674, + "time_taken": 5.646533480001381 } } }, - "KernelThinning": { + "Compress++": { "25": { "0": { - "accuracy": 0.4526808559894562, - "time_taken": 79.91361540300022 + "accuracy": 0.5006005764007568, + "time_taken": 48.42374569299864 }, "1": { - "accuracy": 0.40636226534843445, - "time_taken": 11.441792544000236 + "accuracy": 0.42466992139816284, + "time_taken": 15.439281265003956 }, "2": { - "accuracy": 0.48499375581741333, - "time_taken": 28.38083182599985 + "accuracy": 0.4636856019496918, + "time_taken": 27.897809760004748 }, "3": { - "accuracy": 0.418967604637146, - "time_taken": 31.02636636399984 + "accuracy": 0.4906962215900421, + "time_taken": 13.080031949997647 }, "4": { - "accuracy": 0.4386753439903259, - "time_taken": 22.148064952000823 + "accuracy": 0.4217684864997864, + "time_taken": 20.284750368999084 } }, "50": { "0": { - "accuracy": 0.6550991535186768, - "time_taken": 40.07802116200037 + "accuracy": 0.5905998349189758, + "time_taken": 25.081396187990322 }, "1": { - "accuracy": 0.6222995519638062, - "time_taken": 7.831840456000464 + "accuracy": 0.608699381351471, + "time_taken": 6.843037225000444 }, "2": { - "accuracy": 0.638399600982666, - "time_taken": 10.783056123999813 + "accuracy": 0.58219975233078, + "time_taken": 7.7891707769886125 }, "3": { - "accuracy": 0.6058999300003052, - "time_taken": 9.712099656000646 + "accuracy": 0.4834998846054077, + "time_taken": 7.2219549550063675 }, "4": { - "accuracy": 0.5935994386672974, - "time_taken": 11.725060616000519 + "accuracy": 0.5665996670722961, + "time_taken": 9.369507609997527 } }, "100": { "0": { - "accuracy": 0.7074002027511597, - "time_taken": 23.023189279999315 + "accuracy": 0.7213001251220703, + "time_taken": 17.787696682993555 }, "1": { - "accuracy": 0.6904999613761902, - "time_taken": 5.541293954999674 + "accuracy": 0.7342995405197144, + "time_taken": 3.8416210310097085 }, "2": { - "accuracy": 0.7206001281738281, - "time_taken": 5.422012566999911 + "accuracy": 0.6959000825881958, + "time_taken": 4.702277585005504 }, "3": { - "accuracy": 0.6836000680923462, - "time_taken": 3.8105385219996606 + "accuracy": 0.7541998028755188, + "time_taken": 4.814263307998772 }, "4": { - "accuracy": 0.7140999436378479, - "time_taken": 5.527283147000162 + "accuracy": 0.7210000157356262, + "time_taken": 4.949811800004682 } }, "500": { "0": { - "accuracy": 0.8555689454078674, - "time_taken": 12.997218095999415 + "accuracy": 0.8381410241127014, + "time_taken": 7.494220376000158 }, "1": { - "accuracy": 0.8521634936332703, - "time_taken": 3.805464752000262 + "accuracy": 0.8441506624221802, + "time_taken": 3.3129628419992514 }, "2": { - "accuracy": 0.8616787195205688, - "time_taken": 4.052976539000156 + "accuracy": 0.8579727411270142, + "time_taken": 3.4562966030061943 }, "3": { - "accuracy": 0.859375, - "time_taken": 3.713409908000358 + "accuracy": 0.8535656929016113, + "time_taken": 3.3829493960074615 }, "4": { - "accuracy": 0.8414463400840759, - "time_taken": 3.7632819319996997 + "accuracy": 0.8521634936332703, + "time_taken": 3.694734361008159 } }, "1000": { "0": { - "accuracy": 0.8792067170143127, - "time_taken": 9.468509635999908 + "accuracy": 0.8771033883094788, + "time_taken": 7.094734907004749 }, "1": { - "accuracy": 0.8872195482254028, - "time_taken": 3.8311882420002803 + "accuracy": 0.8756009936332703, + "time_taken": 
3.0448469719995046 }, "2": { - "accuracy": 0.8777043223381042, - "time_taken": 3.494454018000397 + "accuracy": 0.8830128312110901, + "time_taken": 3.1766963270056294 }, "3": { - "accuracy": 0.8869190812110901, - "time_taken": 3.8445811410001625 + "accuracy": 0.8806089758872986, + "time_taken": 3.6595106110034976 }, "4": { - "accuracy": 0.8827123641967773, - "time_taken": 3.8723056460003136 + "accuracy": 0.8797075152397156, + "time_taken": 3.1855167699977756 } }, "5000": { "0": { - "accuracy": 0.9207732677459717, - "time_taken": 14.704426218000663 + "accuracy": 0.9220753312110901, + "time_taken": 6.997968723997474 + }, + "1": { + "accuracy": 0.9252804517745972, + "time_taken": 4.05370223299542 + }, + "2": { + "accuracy": 0.9270833730697632, + "time_taken": 4.276571049995255 + }, + "3": { + "accuracy": 0.9282852411270142, + "time_taken": 4.490389550992404 + }, + "4": { + "accuracy": 0.9259815812110901, + "time_taken": 4.635322599002393 + } + } + }, + "Probabilistic Iterative Herding": { + "25": { + "0": { + "accuracy": 0.5539213418960571, + "time_taken": 22.00237893400481 + }, + "1": { + "accuracy": 0.5617246031761169, + "time_taken": 30.902002311995602 + }, + "2": { + "accuracy": 0.5159062147140503, + "time_taken": 13.908021073002601 + }, + "3": { + "accuracy": 0.4784916043281555, + "time_taken": 30.6841775149951 + }, + "4": { + "accuracy": 0.5513207316398621, + "time_taken": 23.426923535997048 + } + }, + "50": { + "0": { + "accuracy": 0.6849998235702515, + "time_taken": 13.994797082996229 + }, + "1": { + "accuracy": 0.6985998153686523, + "time_taken": 12.24240016200929 + }, + "2": { + "accuracy": 0.6409992575645447, + "time_taken": 10.90571705000184 + }, + "3": { + "accuracy": 0.6208995580673218, + "time_taken": 9.279130113995052 + }, + "4": { + "accuracy": 0.6784992814064026, + "time_taken": 8.457788833999075 + } + }, + "100": { + "0": { + "accuracy": 0.7381998300552368, + "time_taken": 11.28686753900547 }, "1": { - "accuracy": 0.9328926205635071, - "time_taken": 6.878256106000663 + "accuracy": 0.7427999973297119, + "time_taken": 5.407744587995694 }, "2": { - "accuracy": 0.9287860989570618, - "time_taken": 6.707938498000658 + "accuracy": 0.7714999914169312, + "time_taken": 7.464625088003231 }, "3": { - "accuracy": 0.9293870329856873, - "time_taken": 4.895929524000167 + "accuracy": 0.7434998154640198, + "time_taken": 6.5014143439912 }, "4": { + "accuracy": 0.7387998104095459, + "time_taken": 6.165277772000991 + } + }, + "500": { + "0": { + "accuracy": 0.8561698794364929, + "time_taken": 8.956057899005827 + }, + "1": { + "accuracy": 0.858473539352417, + "time_taken": 4.108647701999871 + }, + "2": { + "accuracy": 0.8372396230697632, + "time_taken": 4.213687284005573 + }, + "3": { + "accuracy": 0.8531650900840759, + "time_taken": 5.0821179729973665 + }, + "4": { + "accuracy": 0.849459171295166, + "time_taken": 4.7921027800039155 + } + }, + "1000": { + "0": { + "accuracy": 0.8831129670143127, + "time_taken": 9.668020572993555 + }, + "1": { + "accuracy": 0.8720953464508057, + "time_taken": 4.814860455997405 + }, + "2": { + "accuracy": 0.8809094429016113, + "time_taken": 4.780255288002081 + }, + "3": { + "accuracy": 0.8863181471824646, + "time_taken": 4.911866419992293 + }, + "4": { + "accuracy": 0.8823117017745972, + "time_taken": 5.024809098991682 + } + }, + "5000": { + "0": { + "accuracy": 0.922776460647583, + "time_taken": 23.844807379995473 + }, + "1": { + "accuracy": 0.9272836446762085, + "time_taken": 17.11210143000062 + }, + "2": { + "accuracy": 0.9250801205635071, + "time_taken": 
16.5623230929923 + }, + "3": { "accuracy": 0.9291867017745972, - "time_taken": 5.907913719000135 + "time_taken": 17.20165599399479 + }, + "4": { + "accuracy": 0.9233773946762085, + "time_taken": 15.688548125996022 } } } diff --git a/benchmark/mnist_benchmark_visualiser.py b/benchmark/mnist_benchmark_visualiser.py index 253ed235..ad022d0e 100644 --- a/benchmark/mnist_benchmark_visualiser.py +++ b/benchmark/mnist_benchmark_visualiser.py @@ -162,6 +162,8 @@ def plot_performance( bar_width = 0.8 / n_algorithms # Divide available space for bars index = np.arange(len(coreset_sizes)) # x positions for coreset sizes + plt.figure(figsize=(12, 8)) # Bigger figure size + for i, algo in enumerate(stats): # Plot the bars for mean values plt.bar( @@ -170,9 +172,12 @@ def plot_performance( bar_width, label=algo, color=f"C{i}", - alpha=0.7, + alpha=0.8, + edgecolor="black", + linewidth=1.5, ) + # Add error bars with a larger capsize plt.errorbar( index + i * bar_width, stats[algo]["means"], @@ -182,18 +187,19 @@ def plot_performance( ], fmt="none", ecolor="black", - capsize=5, + capsize=7, # Larger cap size for better visibility alpha=0.9, + elinewidth=2, # Thicker error bars ) - # Overlay individual points as dots + # Overlay individual points as larger dots for j, size in enumerate(coreset_sizes): x_positions = ( index[j] + i * bar_width + np.random.uniform( - -0.01 * bar_width, - 0.01 * bar_width, + -0.02 * bar_width, + 0.02 * bar_width, len(stats[algo]["points"][size]), ) ) @@ -201,21 +207,38 @@ def plot_performance( x_positions, stats[algo]["points"][size], color=f"C{i}", - s=10, + s=40, # Larger dots for better visibility + edgecolor="black", + linewidth=0.8, + alpha=0.8, ) - # Add labels, titles, and other plot formatting - plt.xlabel("Coreset Size") - plt.ylabel(ylabel) + # Add labels, titles, and formatting + plt.ylabel(ylabel, fontsize=20, fontweight="bold") + plt.xlabel("Coreset Size", fontsize=20, fontweight="bold") + if log_scale: plt.yscale("log") - plt.title(title) + + plt.title(title, fontsize=24, fontweight="bold") plt.xticks( index + bar_width * (n_algorithms - 1) / 2, [str(size) for size in coreset_sizes], + fontsize=18, ) - plt.legend() - plt.grid(True, linestyle="--", alpha=0.7) + plt.yticks(fontsize=18) + + # Enhanced legend styling + plt.legend( + loc="lower center", + bbox_to_anchor=(0.5, -0.25), + ncol=(n_algorithms + 1) // 2, + fontsize=18, + frameon=True, + edgecolor="black", + ) + + plt.grid(True, linestyle="--", alpha=0.5) plt.tight_layout() @@ -252,16 +275,6 @@ def main() -> None: "Algorithm Performance (Accuracy) for Different Coreset Sizes", ) - plt.figtext( - 0.5, - 0.01, - "Plot showing the mean performance of algorithms with error bars" - " representing min-max ranges", - wrap=True, - horizontalalignment="center", - fontsize=8, - ) - plt.show() # Plot time taken results @@ -275,12 +288,12 @@ def main() -> None: plt.figtext( 0.5, - 0.01, + 0.91, "Plot showing the mean time taken to generate coresets and train MNIST" "classifier with coreset sizes with error bars representing min-max ranges", wrap=True, horizontalalignment="center", - fontsize=8, + fontsize=12, ) plt.show() @@ -297,12 +310,12 @@ def main() -> None: plt.figtext( 0.5, - 0.01, + 0.91, "Plot showing the mean time taken to generate coresets of different" " coreset sizes with error bars representing min-max ranges", wrap=True, horizontalalignment="center", - fontsize=8, + fontsize=12, ) plt.show() diff --git a/benchmark/mnist_time_results.json b/benchmark/mnist_time_results.json index 4406fc12..1af9417d 100644 
--- a/benchmark/mnist_time_results.json +++ b/benchmark/mnist_time_results.json @@ -1,187 +1,261 @@ { - "RandomSample": { + "Random Sample": { "25": { - "0": 0.6588367329986795, - "1": 0.0016759669997554738, - "2": 0.001854901998740388, - "3": 0.0016810239994811127, - "4": 0.0016000750001694541 + "0": 0.7094588610000301, + "1": 0.0017291089999389442, + "2": 0.001685944999962885, + "3": 0.001793641000176649, + "4": 0.0018410610000501038 }, "50": { - "0": 0.6135869649988308, - "1": 0.0016016179997677682, - "2": 0.001721464001093409, - "3": 0.00153965900062758, - "4": 0.0016626109991193516 + "0": 0.6013920729999427, + "1": 0.0017869449999352582, + "2": 0.0019502580000789749, + "3": 0.0016377110000576067, + "4": 0.0017066830000658229 }, "100": { - "0": 0.6118492949990468, - "1": 0.0015750410002510762, - "2": 0.0018720730004133657, - "3": 0.0015939309996610973, - "4": 0.0015289340008166619 + "0": 0.6137222579999388, + "1": 0.0016203530000211686, + "2": 0.0018824950000180252, + "3": 0.001643035999904896, + "4": 0.001717244000019491 }, "500": { - "0": 0.6197238580007252, - "1": 0.0018075179996230872, - "2": 0.001730617999783135, - "3": 0.0015197210013866425, - "4": 0.0016761769984441344 + "0": 0.6200576559999718, + "1": 0.001770888000010018, + "2": 0.0017450970001391397, + "3": 0.0016370069999993575, + "4": 0.001675381000040943 }, "1000": { - "0": 0.6255511849994946, - "1": 0.0015996750007616356, - "2": 0.001622727999347262, - "3": 0.0015904940009932034, - "4": 0.0015960620003170334 + "0": 0.6196925570000076, + "1": 0.002092210000000705, + "2": 0.0017158939999717404, + "3": 0.0015659470000173314, + "4": 0.001732565000111208 } }, - "RPCholesky": { + "RP Cholesky": { "25": { - "0": 1.519872093000231, - "1": 0.009014817998831859, - "2": 0.009139535000940668, - "3": 0.008957120000559371, - "4": 0.008898485999452532 + "0": 1.5147350790000473, + "1": 0.009266948999993474, + "2": 0.009186939999835886, + "3": 0.008905391000098462, + "4": 0.008803853000017625 }, "50": { - "0": 1.458239367000715, - "1": 0.017449393999413587, - "2": 0.017646361000515753, - "3": 0.017497655000624945, - "4": 0.01741153100010706 + "0": 1.4470037309999952, + "1": 0.01754613100001734, + "2": 0.017912022000018624, + "3": 0.017566263000162508, + "4": 0.017581992000032187 }, "100": { - "0": 1.48360984700048, - "1": 0.03753473100005067, - "2": 0.037234362000162946, - "3": 0.037164757999562426, - "4": 0.037121964000107255 + "0": 1.4746420580000859, + "1": 0.03754237800001192, + "2": 0.03769687899989549, + "3": 0.037440069000012954, + "4": 0.03766440699996565 }, "500": { - "0": 1.9741088339997077, - "1": 0.3905957330007368, - "2": 0.3912712890014518, - "3": 0.3908946850006032, - "4": 0.39045989899932465 + "0": 2.1145162100000334, + "1": 0.39179865499988864, + "2": 0.3916123300000436, + "3": 0.39198892800004614, + "4": 0.39262377000000015 }, "1000": { - "0": 2.7680352509996737, - "1": 1.2777003790015442, - "2": 1.278581608999957, - "3": 1.27806329699888, - "4": 1.2789005979993817 + "0": 2.795978721000097, + "1": 1.2802431059999435, + "2": 1.2806992569999238, + "3": 1.2820121699999163, + "4": 1.2820419469999251 } }, - "KernelHerding": { + "Kernel Herding": { "25": { - "0": 2.8221302519996243, - "1": 0.2740416950000508, - "2": 0.2751288489998842, - "3": 0.27401620499949786, - "4": 0.2768986720002431 + "0": 2.610353960999987, + "1": 0.28118345199993655, + "2": 0.285841161999997, + "3": 0.27556166899989876, + "4": 0.2756238229999326 }, "50": { - "0": 1.577164325000922, - "1": 0.27728562999982387, - "2": 0.2791873059995851, - "3": 
0.27608360599879234, - "4": 0.2759777829996892 + "0": 1.5984972850000077, + "1": 0.28247100600003705, + "2": 0.288643810999929, + "3": 0.28012494099994, + "4": 0.2785548000001654 }, "100": { - "0": 1.5832591270009289, - "1": 0.2820007230002375, - "2": 0.28537204100030067, - "3": 0.2851871380007651, - "4": 0.2826830610010802 + "0": 1.9940642040000967, + "1": 0.2899575900000855, + "2": 0.2972997360000136, + "3": 0.2834464029999708, + "4": 0.28465508800013595 }, "500": { - "0": 1.7189907629999652, - "1": 0.3274141089987097, - "2": 0.3350357380004425, - "3": 0.3309756380003819, - "4": 0.3338158980004664 + "0": 1.7415117790000068, + "1": 0.33662594000008994, + "2": 0.3445417009997982, + "3": 0.33440655700019306, + "4": 0.33455426499995156 }, "1000": { - "0": 1.8001215259992023, - "1": 0.3800924019997183, - "2": 0.3808283049984311, - "3": 0.3794397999990906, - "4": 0.3793001079993701 + "0": 1.8535430869999345, + "1": 0.3871720610000011, + "2": 0.3925284009999359, + "3": 0.38465819700013526, + "4": 0.3837121839999327 } }, - "SteinThinning": { + "Stein Thinning": { "25": { - "0": 4.062421253998764, - "1": 3.7067944970003737, - "2": 3.592493141999512, - "3": 3.622419968000031, - "4": 3.645126594999965 + "0": 3.8074047820000487, + "1": 3.824364295999999, + "2": 3.6565696080001544, + "3": 3.6072808989999885, + "4": 3.6074516559999665 }, "50": { - "0": 3.842947594999714, - "1": 3.841391066000142, - "2": 3.75092935199973, - "3": 3.7667867290001595, - "4": 3.7816334729996015 + "0": 4.19630774999996, + "1": 3.923172228999988, + "2": 3.862083046999942, + "3": 3.7715675970000575, + "4": 3.7609794219999912 }, "100": { - "0": 4.318612201999713, - "1": 3.9970832929993776, - "2": 3.9448298860006616, - "3": 3.9489083890002803, - "4": 3.923231038999802 + "0": 3.9830489679999346, + "1": 4.064505925999924, + "2": 4.000870295999903, + "3": 3.931446998000183, + "4": 3.9275171780000164 }, "500": { - "0": 5.439173460999882, - "1": 5.453762462000668, - "2": 5.412860956999793, - "3": 5.417045615000461, - "4": 5.401182251000137 + "0": 5.851763790000064, + "1": 5.488356060000001, + "2": 5.444121625000207, + "3": 5.364347303999921, + "4": 5.428865189000135 }, "1000": { - "0": 7.2683715949988255, - "1": 9.998894977999953, - "2": 6.865873652999653, - "3": 6.937165109000489, - "4": 6.851976205998653 + "0": 6.939637783999956, + "1": 10.965542330000062, + "2": 6.903962380999928, + "3": 6.899562562000028, + "4": 6.997262595000166 } }, - "KernelThinning": { + "Kernel Thinning": { "25": { - "0": 57.40046906799944, - "1": 0.4691025020001689, - "2": 0.4655119300005026, - "3": 0.4667137439992075, - "4": 0.4651397059988085 + "0": 66.63043982800002, + "1": 0.4887004889999389, + "2": 0.4734878320000462, + "3": 0.47251907500003654, + "4": 0.47228892800012545 }, "50": { - "0": 31.640804945000127, - "1": 0.4824182379998092, - "2": 0.4758363860000827, - "3": 0.4767543710004247, - "4": 0.4776206539991108 + "0": 33.34872766800004, + "1": 0.5037238669999624, + "2": 0.488602832999959, + "3": 0.4856810139999652, + "4": 0.48470954399999755 }, "100": { - "0": 19.31581526300033, - "1": 0.4937812940006552, - "2": 0.4880024880003475, - "3": 0.49104342400096357, - "4": 0.48902785900099843 + "0": 21.044754845999933, + "1": 0.5149621710000929, + "2": 0.5004036840000481, + "3": 0.49941481700011536, + "4": 0.49638310099999217 }, "500": { - "0": 6.845564752999053, - "1": 0.47625808600059827, - "2": 0.47526046200073324, - "3": 0.47945132499989995, - "4": 0.4767048789999535 + "0": 7.190681866999967, + "1": 0.5032470499999135, + "2": 0.48689451700010977, + 
"3": 0.48538076199997704, + "4": 0.48296712800015484 }, "1000": { - "0": 6.7132718520006165, - "1": 0.5458353630001511, - "2": 0.5415765449997707, - "3": 0.5430132919991593, - "4": 0.5429429139985587 + "0": 6.903874932000008, + "1": 0.5668272540000316, + "2": 0.5528795729999274, + "3": 0.5505331360000127, + "4": 0.5481448209998234 + } + }, + "Compress++": { + "25": { + "0": 29.412537138999937, + "1": 0.05838682499995684, + "2": 0.059164918999840665, + "3": 0.0580135689999679, + "4": 0.057627395000054094 + }, + "50": { + "0": 18.233435394000026, + "1": 0.06534597700010636, + "2": 0.06501679099983448, + "3": 0.06366876599986426, + "4": 0.06367448700007117 + }, + "100": { + "0": 8.529528905000006, + "1": 0.0648159620000115, + "2": 0.06605522800009567, + "3": 0.06563545900007739, + "4": 0.06551302200000464 + }, + "500": { + "0": 4.349865860000023, + "1": 0.0864839649999567, + "2": 0.08642254599999433, + "3": 0.08871303000000808, + "4": 0.08833459100014807 + }, + "1000": { + "0": 3.6908661520000123, + "1": 0.10315516699995442, + "2": 0.1031596630000422, + "3": 0.1060016710000582, + "4": 0.10459229200000664 + } + }, + "Probabilistic Iterative Herding": { + "25": { + "0": 4.567380666000076, + "1": 0.3798184410001113, + "2": 0.35353611400000773, + "3": 0.3526079970001774, + "4": 0.35121752999998535 + }, + "50": { + "0": 4.6527845999999045, + "1": 0.41235254699995494, + "2": 0.3860566419998577, + "3": 0.38475866199996744, + "4": 0.38376728100001856 + }, + "100": { + "0": 4.8177640440000005, + "1": 0.47743541500005904, + "2": 0.45443064400001276, + "3": 0.45361023100008424, + "4": 0.4517581079999218 + }, + "500": { + "0": 5.6638604600000235, + "1": 1.0731660229999989, + "2": 1.0509365340001295, + "3": 1.0473623349998888, + "4": 1.0499444979998316 + }, + "1000": { + "0": 6.847450321999986, + "1": 1.8655788040000516, + "2": 1.843780653000067, + "3": 1.8360947809999288, + "4": 1.834880068000075 } } } diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index 110d471b..77fe8e5d 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -29,10 +29,37 @@ import numpy as np import umap from jax import random +from matplotlib import pyplot as plt -from coreax.benchmark_util import get_solver_name, initialise_solvers +from coreax.benchmark_util import initialise_solvers from coreax.data import Data -from coreax.solvers import MapReduce + + +def plot_selected_frames(umap_data, selected_indices, action_frames, solver_name): + """ + Plot the selected frames and action frames on a bar chart. + + :param umap_data: The UMAP-transformed data. + :param selected_indices: Indices of the selected frames. + :param action_frames: Indices of the action frames. + :param solver_name: The name of the solver used. 
+ """ + x = np.arange(len(umap_data)) + y = np.zeros(len(umap_data)) + y[selected_indices] = 1.0 + + z = np.zeros(len(umap_data)) + z[jnp.intersect1d(selected_indices, action_frames)] = 1.0 + + plt.figure(figsize=(20, 3)) + plt.bar(x, y, alpha=0.5, label="Selected Frames") + plt.bar(x, z, label="Action Frames") + plt.xlabel("Frame", fontsize=18) + plt.ylabel("Chosen", fontsize=18) + plt.title(f"Selected Frames for {solver_name}", fontsize=24, fontweight="bold") + plt.legend() + plt.tight_layout() + plt.show() def benchmark_coreset_algorithms( @@ -62,16 +89,13 @@ def benchmark_coreset_algorithms( umap_model = umap.UMAP(densmap=True, n_components=25) umap_data = jnp.asarray(umap_model.fit_transform(reshaped_data)) + print("umap_data_shape", umap_data.shape) - solver_factories = initialise_solvers(Data(umap_data), random.PRNGKey(45)) - for solver_creator in solver_factories: + solver_factories = initialise_solvers( + Data(umap_data), random.PRNGKey(45), cpp_oversampling_factor=3 + ) + for solver_name, solver_creator in solver_factories.items(): solver = solver_creator(coreset_size) - - # There is no need to use MapReduce as the data-size is small - if isinstance(solver, MapReduce): - solver = solver.base_solver - - solver_name = get_solver_name(solver_creator) data = Data(umap_data) start_time = time.perf_counter() @@ -80,13 +104,19 @@ def benchmark_coreset_algorithms( selected_indices = np.sort(np.asarray(coreset.unweighted_indices)) - # Extract corresponding frames from original data and save GIF coreset_frames = raw_data[selected_indices] output_gif_path = out_dir / f"{solver_name}_coreset.gif" imageio.v3.imwrite(output_gif_path, coreset_frames, loop=0) print(f"Saved {solver_name} coreset GIF to {output_gif_path}") print(f"time taken: {solver_name:<25} {duration:<30.4f}") + plot_selected_frames( + umap_data=umap_data, + selected_indices=selected_indices, + action_frames=np.arange(63, 85), + solver_name=solver_name, + ) + if __name__ == "__main__": benchmark_coreset_algorithms() diff --git a/coreax/benchmark_util.py b/coreax/benchmark_util.py index e30e036c..1e52e8e7 100644 --- a/coreax/benchmark_util.py +++ b/coreax/benchmark_util.py @@ -21,6 +21,7 @@ """ from collections.abc import Callable +from typing import Optional, Union import jax.numpy as jnp import numpy as np @@ -30,6 +31,8 @@ from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic from coreax.score_matching import KernelDensityMatching from coreax.solvers import ( + CompressPlusPlus, + IterativeKernelHerding, KernelHerding, KernelThinning, MapReduce, @@ -70,9 +73,12 @@ def calculate_delta(n: int) -> Float[Array, "1"]: return jnp.array(1 / n) -def initialise_solvers( - train_data_umap: Data, key: KeyArrayLike -) -> list[Callable[[int], Solver]]: +def initialise_solvers( # noqa: C901 + train_data_umap: Data, + key: KeyArrayLike, + cpp_oversampling_factor: int, + leaf_size: Optional[int] = None, +) -> dict[str, Callable[[int], Solver]]: """ Initialise and return a list of solvers for various coreset algorithms. @@ -85,7 +91,11 @@ def initialise_solvers( :param train_data_umap: The UMAP-transformed training data used for length scale estimation for ``SquareExponentialKernel``. :param key: The random key for initialising random solvers. - :return: A list of solvers functions for different coreset algorithms. + :param cpp_oversampling_factor: The oversampling factor for `Compress++`. + :param leaf_size: The leaf size to be used in `MapReduce` solvers. 
If not provided + (i.e., `None`), `MapReduce` solvers will not be used. + :return: A dictionary where the keys are solver names and the values are + corresponding solver functions for different coreset algorithms. """ # Set up kernel using median heuristic num_data_points = len(train_data_umap) @@ -97,16 +107,16 @@ def initialise_solvers( kernel = SquaredExponentialKernel(length_scale=length_scale) sqrt_kernel = kernel.get_sqrt_kernel(16) - def _get_thinning_solver(_size: int) -> MapReduce: + def _get_thinning_solver(_size: int) -> Union[KernelThinning, MapReduce]: """ - Set up KernelThinning to use ``MapReduce``. + Set up kernel thinning solver. - Create a KernelThinning solver with the specified size and return - it along with a MapReduce object for reducing a large dataset like - MNIST dataset. + If the `leaf_size` is provided, the solver uses ``MapReduce`` to reduce + datasets. :param _size: The size of the coreset to be generated. - :return: MapReduce solver with KernelThinning as the base solver. + :return: A `KernelThinning` solver if `leaf_size` is `None`, otherwise a + `MapReduce` solver with `KernelThinning` as the base solver. """ thinning_solver = KernelThinning( coreset_size=_size, @@ -115,33 +125,36 @@ def _get_thinning_solver(_size: int) -> MapReduce: delta=calculate_delta(num_data_points).item(), sqrt_kernel=sqrt_kernel, ) + if leaf_size is None: + return thinning_solver + return MapReduce(thinning_solver, leaf_size=leaf_size) - return MapReduce(thinning_solver, leaf_size=3 * _size) - - def _get_herding_solver(_size: int) -> MapReduce: + def _get_herding_solver(_size: int) -> Union[KernelHerding, MapReduce]: """ - Set up KernelHerding to use ``MapReduce``. + Set up kernel herding solver. - Create a KernelHerding solver with the specified size and return - it along with a MapReduce object for reducing a large dataset like - MNIST dataset. + If the `leaf_size` is provided, the solver uses ``MapReduce`` to reduce + datasets. :param _size: The size of the coreset to be generated. - :return: MapReduce solver with KernelHerding as the base solver. + :return: A `KernelHerding` solver if `leaf_size` is `None`, otherwise a + `MapReduce` solver with `KernelHerding` as the base solver. """ herding_solver = KernelHerding(_size, kernel) - return MapReduce(herding_solver, leaf_size=3 * _size) + if leaf_size is None: + return herding_solver + return MapReduce(herding_solver, leaf_size=leaf_size) - def _get_stein_solver(_size: int) -> MapReduce: + def _get_stein_solver(_size: int) -> Union[SteinThinning, MapReduce]: """ - Set up Stein Thinning to use ``MapReduce``. + Set up Stein thinning solver. - Create a SteinThinning solver with the specified coreset size, - using ``KernelDensityMatching`` score function for matching on - a subset of the dataset. + If the `leaf_size` is provided, the solver uses ``MapReduce`` to reduce + datasets. :param _size: The size of the coreset to be generated. - :return: MapReduce solver with SteinThinning as the base solver. + :return: A `SteinThinning` solver if `leaf_size` is `None`, otherwise a + `MapReduce` solver with `SteinThinning` as the base solver. 
""" # Generate small dataset for ScoreMatching for Stein Kernel @@ -152,7 +165,9 @@ def _get_stein_solver(_size: int) -> MapReduce: stein_solver = SteinThinning( coreset_size=_size, kernel=stein_kernel, regularise=False ) - return MapReduce(stein_solver, leaf_size=3 * _size) + if leaf_size is None: + return stein_solver + return MapReduce(stein_solver, leaf_size=leaf_size) def _get_random_solver(_size: int) -> RandomSample: """ @@ -174,29 +189,80 @@ def _get_rp_solver(_size: int) -> RPCholesky: rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key) return rp_solver - return [ - _get_random_solver, - _get_rp_solver, - _get_herding_solver, - _get_stein_solver, - _get_thinning_solver, - ] + def _get_compress_solver(_size: int) -> CompressPlusPlus: + """ + Set up Compress++ solver. + :param _size: The size of the coreset to be generated. + :return: A Compress++ solver. + """ + compress_solver = CompressPlusPlus( + coreset_size=_size, + kernel=kernel, + random_key=key, + delta=calculate_delta(num_data_points).item(), + sqrt_kernel=sqrt_kernel, + g=cpp_oversampling_factor, + ) + return compress_solver -def get_solver_name(solver: Callable[[int], Solver]) -> str: - """ - Get the name of the solver. + def _get_probabilistic_herding_solver( + _size: int, + ) -> Union[IterativeKernelHerding, MapReduce]: + """ + Set up KernelHerding with probabilistic selection. - This function extracts and returns the name of the solver class. - If ``_solver`` is an instance of :class:`~coreax.solvers.MapReduce`, it retrieves - the name of the :class:`~coreax.solvers.MapReduce.base_solver` class instead. + If the `leaf_size` is provided, the solver uses ``MapReduce`` to reduce + datasets. - :param solver: An instance of a solver, such as `MapReduce` or `RandomSample`. - :return: The name of the solver class. - """ - # Evaluate solver function to get an instance to interrogate - # Don't just inspect type annotations, as they may be incorrect - not robust - solver_instance = solver(1) - if isinstance(solver_instance, MapReduce): - return type(solver_instance.base_solver).__name__ - return type(solver_instance).__name__ + :param _size: The size of the coreset to be generated. + :return: An `IterativeKernelHerding` solver if `leaf_size` is `None`, otherwise + a `MapReduce` solver with `IterativeKernelHerding` as the base solver. + """ + herding_solver = IterativeKernelHerding( + coreset_size=_size, + kernel=kernel, + probabilistic=True, + temperature=0.001, + random_key=key, + num_iterations=5, + ) + if leaf_size is None: + return herding_solver + return MapReduce(herding_solver, leaf_size=leaf_size) + + def _get_iterative_herding_solver( + _size: int, + ) -> Union[IterativeKernelHerding, MapReduce]: + """ + Set up KernelHerding with probabilistic selection. + + If the `leaf_size` is provided, the solver uses ``MapReduce`` to reduce + datasets. + + :param _size: The size of the coreset to be generated. + :return: An `IterativeKernelHerding` solver if `leaf_size` is `None`, otherwise + a `MapReduce` solver with `IterativeKernelHerding` as the base solver. 
+        """
+        herding_solver = IterativeKernelHerding(
+            coreset_size=_size,
+            kernel=kernel,
+            probabilistic=False,
+            temperature=0.001,
+            random_key=key,
+            num_iterations=5,
+        )
+        if leaf_size is None:
+            return herding_solver
+        return MapReduce(herding_solver, leaf_size=leaf_size)
+
+    return {
+        "Random Sample": _get_random_solver,
+        "RP Cholesky": _get_rp_solver,
+        "Kernel Herding": _get_herding_solver,
+        "Stein Thinning": _get_stein_solver,
+        "Kernel Thinning": _get_thinning_solver,
+        "Compress++": _get_compress_solver,
+        "Probabilistic Iterative Herding": _get_probabilistic_herding_solver,
+        "Iterative Herding": _get_iterative_herding_solver,
+    }
diff --git a/coreax/solvers/__init__.py b/coreax/solvers/__init__.py
index 09d500d6..141e1492 100644
--- a/coreax/solvers/__init__.py
+++ b/coreax/solvers/__init__.py
@@ -27,6 +27,7 @@
     GreedyKernelPoints,
     GreedyKernelPointsState,
     HerdingState,
+    IterativeKernelHerding,
     KernelHerding,
     KernelThinning,
     RandomSample,
@@ -61,4 +62,5 @@
     "CaratheodoryRecombination",
     "TreeRecombination",
     "CompressPlusPlus",
+    "IterativeKernelHerding",
 ]
diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index fb52bd70..ae7875a3 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -1439,3 +1439,77 @@ def _compress_plus_plus(indices: Array) -> Array:
 
     plus_plus_indices = _compress_plus_plus(clipped_indices)
     return Coresubset(Data(plus_plus_indices), dataset), None
+
+
+class IterativeKernelHerding(ExplicitSizeSolver):
+    r"""
+    Iterative Kernel Herding - perform multiple refinements of Kernel Herding.
+
+    Reduce using :class:`~coreax.solvers.KernelHerding`, then refine the result a set
+    number of times. All parameters except `num_iterations` are passed to
+    :class:`~coreax.solvers.KernelHerding`.
+
+    :param coreset_size: The desired size of the solved coreset.
+    :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
+        kernel function
+        :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`.
+    :param unique: Boolean that ensures the resulting coresubset will only contain
+        unique elements.
+    :param block_size: Block size passed to
+        :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
+    :param unroll: Unroll parameter passed to
+        :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
+    :param probabilistic: If :data:`True`, the elements are chosen probabilistically at
+        each iteration. Otherwise, standard Kernel Herding is run.
+    :param temperature: Temperature parameter, which controls how uniform the
+        probabilities are for probabilistic selection.
+    :param random_key: Key for random number generation, only used if `probabilistic`.
+    :param num_iterations: Number of refinement iterations.
+    """
+
+    num_iterations: int
+    kernel: ScalarValuedKernel
+    unique: bool = True
+    block_size: Optional[Union[int, tuple[Optional[int], Optional[int]]]] = None
+    unroll: Union[int, bool, tuple[Union[int, bool], Union[int, bool]]] = 1
+    probabilistic: bool = False
+    temperature: Union[float, Scalar] = eqx.field(default=1.0)
+    random_key: KeyArrayLike = eqx.field(default_factory=lambda: jax.random.key(0))
+
+    def reduce(
+        self,
+        dataset: _Data,
+        solver_state: Optional[HerdingState] = None,
+    ) -> tuple[Coresubset[_Data], HerdingState]:
+        """
+        Perform Kernel Herding reduction followed by additional refinement iterations.
+
+        :param dataset: The dataset to process.
+        :param solver_state: Optional solver state.
+        :return: Refined coresubset and final solver state.
+ """ + herding_solver = KernelHerding( + coreset_size=self.coreset_size, + kernel=self.kernel, + unique=self.unique, + block_size=self.block_size, + unroll=self.unroll, + probabilistic=self.probabilistic, + temperature=self.temperature, + random_key=self.random_key, + ) + + coreset, reduced_solver_state = herding_solver.reduce(dataset, solver_state) + + def refine_step(_, state): + coreset, reduced_solver_state = state + coreset, reduced_solver_state = herding_solver.refine( + coreset, reduced_solver_state + ) + return (coreset, reduced_solver_state) + + coreset, reduced_solver_state = lax.fori_loop( + 0, self.num_iterations, refine_step, (coreset, reduced_solver_state) + ) + + return coreset, reduced_solver_state diff --git a/documentation/source/benchmark.rst b/documentation/source/benchmark.rst index bf7303ae..7f21c520 100644 --- a/documentation/source/benchmark.rst +++ b/documentation/source/benchmark.rst @@ -1,12 +1,13 @@ Benchmarking Coreset Algorithms =============================== -In this benchmark, we assess the performance of four different coreset algorithms: +In this benchmark, we assess the performance of different coreset algorithms: :class:`~coreax.solvers.KernelHerding`, :class:`~coreax.solvers.SteinThinning`, :class:`~coreax.solvers.RandomSample`, :class:`~coreax.solvers.RPCholesky` and -:class:`~coreax.solvers.KernelThinning`. Each of these algorithms is evaluated across -four different tests, providing a comparison of their performance and applicability to -various datasets. +:class:`~coreax.solvers.KernelThinning`, :class:`~coreax.solvers.CompressPlusPlus`, +:class:`~coreax.solvers.IterativeKernelHerding`. Each of these algorithms is evaluated +across four different tests, providing a comparison of their performance and +applicability to various datasets. Test 1: Benchmarking Coreset Algorithms on the MNIST Dataset ------------------------------------------------------------ @@ -27,9 +28,10 @@ these steps: into 16 components before applying any coreset algorithm. 4. **Coreset Generation**: Coresets of various sizes are generated using the - different coreset algorithms. For :class:`~coreax.solvers.KernelHerding` and - :class:`~coreax.solvers.SteinThinning`, :class:`~coreax.solvers.MapReduce` is - employed to handle large-scale data. + different coreset algorithms. For :class:`~coreax.solvers.KernelHerding`, + :class:`~coreax.solvers.SteinThinning`, :class:`~coreax.solvers.KernelThinning`, and + :class:`~coreax.solvers.IterativeKernelHerding`, + :class:`~coreax.solvers.MapReduce` is employed to handle large-scale data. 5. **Training**: The model is trained using the selected coresets, and accuracy is measured on the test set of 10,000 images. @@ -39,6 +41,25 @@ these steps: on an **Amazon Web Services EC2 g4dn.12xlarge instance** with 4 NVIDIA T4 Tensor Core GPUs, 48 vCPUs, and 192 GiB memory. +Impact of UMAP and MapReduce on Coreset Performance +--------------------------------------------------- + +In the benchmarking of coreset algorithms, only **Random Sample** can be run without +MapReduce or UMAP without running into memory allocation errors. The other coreset +algorithms require dimensionality reduction and distributed processing to handle +large-scale data efficiently. As a result, the coreset algorithms were not applied +directly to the raw MNIST images. While these preprocessing steps improved efficiency, +they may have impacted the performance of the coreset methods. 
Specifically, +**MapReduce** partitions the dataset into subsets and applies solvers to each partition, +which can reduce accuracy compared to applying solvers directly to the full dataset. +Additionally, **batch normalisation** and **dropout** were used during training to +mitigate over-fitting. These regularisation techniques made the models more robust, +which also means that accuracy did not heavily depend on the specific subset chosen. +The benchmarking test showed that the accuracy remained similar regardless of +the coreset method used, with only small differences, which could potentially be +attributed to the use of these regularisation techniques. + + **Results**: The accuracy of the MLP classifier when trained using the full MNIST dataset (60,000 training images) was 97.31%, serving as a baseline for evaluating the @@ -64,10 +85,10 @@ Test 2: Benchmarking Coreset Algorithms on a Synthetic Dataset -------------------------------------------------------------- In this second test, we evaluate the performance of the coreset algorithms on a -**synthetic dataset**. The dataset consists of 1,000 points in two-dimensional space, +**synthetic dataset**. The dataset consists of 1,024 points in two-dimensional space, generated using :func:`sklearn.datasets.make_blobs`. The process follows these steps: -1. **Dataset**: A synthetic dataset of 1,000 points is generated to test the +1. **Dataset**: A synthetic dataset of 1,024 points is generated to test the quality of coreset algorithms. 2. **Coreset Generation**: Coresets of different sizes (10, 50, 100, and 200 points) @@ -85,7 +106,7 @@ The tables below show the performance metrics (Unweighted MMD, Unweighted KSD, Weighted MMD, Weighted KSD, and Time) for each coreset algorithm and each coreset size. For each metric and coreset size, the best performance score is highlighted in bold. -.. list-table:: Coreset Size 25 (Original Sample Size 1,000) +.. list-table:: Coreset Size 25 (Original Sample Size 1,024) :header-rows: 1 :widths: 20 15 15 15 15 15 @@ -95,38 +116,56 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_MMD - Weighted_KSD - Time - * - KernelHerding - - 0.026319 - - 0.071420 - - 0.008461 - - 0.072526 - - 1.836664 + * - Kernel Herding + - 0.024273 + - 0.072547 + - 0.008471 + - 0.072267 + - 3.859628 * - RandomSample - - 0.105940 - - 0.081013 - - 0.038174 - - *0.077431* - - *1.281091* - * - RPCholesky - - 0.121869 - - *0.059722* - - *0.003283* - - 0.072288 - - 1.576841 - * - SteinThinning - - 0.161923 - - 0.077394 - - 0.030987 - - 0.074365 - - 1.821020 - * - KernelThinning - - *0.014111* - - 0.072134 - - 0.006634 - - 0.072531 - - 9.144707 - -.. list-table:: Coreset Size 50 (Original Sample Size 1,000) + - 0.125471 + - 0.087859 + - 0.037686 + - 0.074856 + - **2.659764** + * - RP Cholesky + - 0.140715 + - **0.059376** + - **0.003011** + - **0.071982** + - 3.312633 + * - Stein Thinning + - 0.165692 + - 0.073476 + - 0.033367 + - 0.073952 + - 3.714297 + * - Kernel Thinning + - 0.014093 + - 0.071987 + - 0.005737 + - 0.072614 + - 23.659113 + * - Compress++ + - 0.010929 + - 0.072254 + - 0.005783 + - 0.072447 + - 15.278997 + * - Probabilistic Iterative Herding + - 0.017470 + - 0.074181 + - 0.007226 + - 0.072694 + - 4.330906 + * - IIterative Herding + - **0.006842** + - 0.072133 + - 0.004978 + - 0.072212 + - 3.399839 + +.. 
list-table:: Coreset Size 50 (Original Sample Size 1,024) :header-rows: 1 :widths: 20 15 15 15 15 15 @@ -136,38 +175,56 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_MMD - Weighted_KSD - Time - * - KernelHerding - - 0.012574 - - 0.072600 - - 0.003843 - - *0.072351* - - 1.863356 + * - Kernel Herding + - 0.014011 + - 0.072273 + - 0.003191 + - 0.072094 + - 3.417109 * - RandomSample - - 0.083379 - - 0.079031 - - 0.008653 - - 0.072867 - - *1.329118* - * - RPCholesky - - 0.154799 - - *0.056437* - - *0.001347* - - 0.072359 - - 1.564009 - * - SteinThinning - - 0.122605 - - 0.079683 - - 0.012048 - - 0.072424 - - 1.849748 - * - KernelThinning - - *0.005397* - - 0.072051 - - 0.002191 - - 0.072453 - - 5.524234 - -.. list-table:: Coreset Size 100 (Original Sample Size 1,000) + - 0.100558 + - 0.080291 + - 0.005518 + - 0.072549 + - **2.575190** + * - RP Cholesky + - 0.136605 + - **0.055552** + - **0.001971** + - 0.072116 + - 3.227958 + * - Stein Thinning + - 0.152293 + - 0.073183 + - 0.017996 + - **0.071682** + - 4.056369 + * - Kernel Thinning + - 0.006482 + - 0.071823 + - 0.002541 + - 0.072183 + - 12.507483 + * - Compress++ + - 0.006065 + - 0.071981 + - 0.002633 + - 0.072257 + - 9.339439 + * - Probabilistic Iterative Herding + - 0.010031 + - 0.072707 + - 0.002906 + - 0.072432 + - 4.279948 + * - IIterative Herding + - **0.003546** + - 0.072107 + - 0.002555 + - 0.072203 + - 3.291645 + +.. list-table:: Coreset Size 100 (Original Sample Size 1,024) :header-rows: 1 :widths: 20 15 15 15 15 15 @@ -177,38 +234,56 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_MMD - Weighted_KSD - Time - * - KernelHerding - - 0.007651 - - *0.071999* - - 0.001814 - - 0.072364 - - 2.185324 + * - Kernel Herding + - 0.007909 + - 0.071763 + - 0.001859 + - 0.072205 + - 3.583433 * - RandomSample - - 0.052402 - - 0.077454 - - 0.001630 - - 0.072480 - - *1.359826* - * - RPCholesky - - 0.087236 - - 0.063822 - - *0.000910* - - 0.072433 - - 1.721290 - * - SteinThinning - - 0.128295 - - 0.082733 - - 0.006041 - - *0.072182* - - 1.893099 - * - KernelThinning - - *0.002591* - - 0.072293 - - 0.001207 - - 0.072394 - - 3.519274 - -.. list-table:: Coreset Size 200 (Original Sample Size 1,000) + - 0.067373 + - 0.077506 + - 0.001673 + - 0.072329 + - **2.631034** + * - RP Cholesky + - 0.091372 + - **0.059889** + - **0.001174** + - 0.072281 + - 3.426726 + * - Stein Thinning + - 0.102536 + - 0.074250 + - 0.007770 + - **0.071809** + - 3.673147 + * - Kernel Thinning + - 0.002811 + - 0.072218 + - 0.001414 + - 0.072286 + - 7.878599 + * - Compress++ + - 0.003343 + - 0.072287 + - 0.001486 + - 0.072283 + - 6.930467 + * - Probabilistic Iterative Herding + - 0.006254 + - 0.072408 + - 0.001649 + - 0.072289 + - 4.381068 + * - IIterative Herding + - **0.002130** + - 0.072142 + - 0.001373 + - 0.072248 + - 3.502385 + +.. 
list-table:: Coreset Size 200 (Original Sample Size 1,024) :header-rows: 1 :widths: 20 15 15 15 15 15 @@ -218,45 +293,88 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_MMD - Weighted_KSD - Time - * - KernelHerding - - 0.004310 - - 0.072341 - - 0.000777 - - 0.072422 - - 1.837929 + * - Kernel Herding + - 0.004259 + - 0.072017 + - 0.001173 + - 0.072242 + - 3.810858 * - RandomSample - - 0.036624 - - 0.072870 - - *0.000584* - - 0.072441 - - *1.367439* - * - RPCholesky - - 0.041140 - - *0.068655* - - 0.000751 - - 0.072430 - - 2.106838 - * - SteinThinning - - 0.148525 - - 0.087512 - - 0.003799 - - *0.072164* - - 1.910560 - * - KernelThinning - - *0.001330* - - 0.072348 - - 0.001014 - - 0.072428 - - 2.565189 + - 0.031644 + - 0.074061 + - 0.001005 + - 0.072271 + - **2.787691** + * - RP Cholesky + - 0.052786 + - **0.065218** + - **0.000784** + - 0.072269 + - 3.545290 + * - Stein Thinning + - 0.098395 + - 0.078290 + - 0.004569 + - **0.071896** + - 3.910901 + * - Kernel Thinning + - **0.001175** + - 0.072160 + - 0.000933 + - 0.072273 + - 5.720256 + * - Compress++ + - 0.001336 + - 0.072193 + - 0.000788 + - 0.072228 + - 6.081252 + * - Probabilistic Iterative Herding + - 0.005056 + - 0.072054 + - 0.000852 + - 0.072287 + - 5.043387 + * - IIterative Herding + - 0.001346 + - 0.072169 + - 0.001020 + - 0.072241 + - 3.699600 + **Visualisation**: The results in this table can be visualised as follows: - .. image:: ../../examples/benchmarking_images/blobs_benchmark_results.png - :alt: Benchmark Results for Synthetic Dataset + .. image:: ../../examples/benchmarking_images/blobs_unweighted_mmd.png + :alt: Line graph visualising the data tables above, plotting unweighted MMD against + coreset size for each of the coreset methods + + **Figure 3**: Unweighted MMD plotted against coreset size for each coreset method. + + .. image:: ../../examples/benchmarking_images/blobs_unweighted_ksd.png + :alt: Line graph visualising the data tables above, plotting unweighted KSD against + coreset size for each of the coreset methods + + **Figure 4**: Unweighted KSD plotted against coreset size for each coreset method. + + .. image:: ../../examples/benchmarking_images/blobs_weighted_mmd.png + :alt: Line graph visualising the data tables above, plotting weighted MMD against + coreset size for each of the coreset methods - **Figure 3**: Line graphs depicting the average performance metrics across 5 runs of - each coreset algorithm on a synthetic dataset. + **Figure 5**: Weighted MMD plotted against coreset size for each coreset method. + + .. image:: ../../examples/benchmarking_images/blobs_weighted_ksd.png + :alt: Line graph visualising the data tables above, plotting weighted KSD against + coreset size for each of the coreset methods + + **Figure 6**: Weighted KSD plotted against coreset size for each coreset method. + + .. image:: ../../examples/benchmarking_images/blobs_time_taken.png + :alt: Line graph visualising the data tables above, plotting time taken against + coreset size for each of the coreset methods + + **Figure 7**: Time taken plotted against coreset size for each coreset method. Test 3: Benchmarking Coreset Algorithms on Pixel Data from an Image ------------------------------------------------------------------- @@ -277,72 +395,100 @@ from an input image. The process follows these steps: **Results**: The following plot visualises the pixels chosen by each coreset algorithm. .. 
image:: ../../examples/benchmarking_images/david_benchmark_results.png
-    :alt: Coreset Visualisation on Image
+    :alt: Plot showing pixels chosen from an image by each coreset algorithm
 
-    **Figure 4**: The original image and pixels selected by each coreset algorithm
+    **Figure 8**: The original image and pixels selected by each coreset algorithm
     plotted side-by-side for visual comparison.
 
-Test 4: Benchmarking Coreset Algorithms on Frame Data from a GIF
-----------------------------------------------------------------
+Test 4: Selecting Key Frames from Video Data
+--------------------------------------------
 
 The fourth and final test evaluates the performance of coreset algorithms on data
-extracted from an input **GIF**. This test involves the following steps:
+extracted from an input **video**. This test involves the following steps:
 
-1. **Input GIF**: A GIF is loaded, and its frames are preprocessed.
+1. **Input Video**: A video is loaded, and its frames are preprocessed.
 
 2. **Dimensionality Reduction**: On each frame data, a density preserving
    :class:`~umap.umap_.UMAP` is applied to reduce dimensionality of each frame to 25.
 
-3. **Coreset Generation**: Coresets are generated using each coreset algorithm, and
-   selected frames are saved as new GIFs.
+3. **Coreset Generation**: For each coreset algorithm, a coreset is generated and the
+   selected frames are saved as a new video.
 
 **Result**:
 
-- GIF files showing the selected frames for each coreset algorithm.
+- Video files showing the selected frames for each coreset algorithm.
 
 .. image:: ../../examples/pounce/pounce.gif
-    :alt: Coreset Visualisation on GIF Frames
+    :alt: Original video showing the sequence of frames before applying
+     coreset algorithms.
 
-    **Gif 1**: Original gif file.
+    **Video 1**: Original video file.
 
 .. image:: ../../examples/benchmarking_images/RandomSample_coreset.gif
-    :alt: Coreset Visualisation on GIF Frames
+    :alt: Video showing the frames selected by Random Sample
 
-    **Gif 2**: Frames selected by Random Sample.
+    **Video 2**: Frames selected by Random Sample.
 
 .. image:: ../../examples/benchmarking_images/SteinThinning_coreset.gif
-    :alt: Coreset Visualisation on GIF Frames
+    :alt: Video showing the frames selected by Stein Thinning
 
-    **Gif 3**: Frames selected by Stein thinning.
+    **Video 3**: Frames selected by Stein Thinning.
 
 .. image:: ../../examples/benchmarking_images/RPCholesky_coreset.gif
-    :alt: Coreset Visualisation on GIF Frames
+    :alt: Video showing the frames selected by RP Cholesky
 
-    **Gif 4**: Frames selected by RP Cholesky.
+    **Video 4**: Frames selected by RP Cholesky.
 
 .. image:: ../../examples/benchmarking_images/KernelHerding_coreset.gif
-    :alt: Coreset Visualisation on GIF Frames
+    :alt: Video showing the frames selected by Kernel Herding
 
-    **Gif 5**: Frames selected by kernel herding.
+    **Video 5**: Frames selected by Kernel Herding.
 
-    .. image:: ../../examples/benchmarking_images/pounce_frames.png
-    :alt: Coreset Visualisation on GIF Frames
+    .. image:: ../../examples/benchmarking_images/KernelThinning_coreset.gif
+    :alt: Video showing the frames selected by Kernel Thinning
 
-    **Figure 5**:Frames chosen by each each coreset algorithm with action frames (the
-    frames in which pounce action takes place) highlighted in red.
+    **Video 6**: Frames selected by Kernel Thinning.
 
-Conclusion
-----------
+    .. 
image:: ../../examples/benchmarking_images/CompressPlusPlus_coreset.gif
+    :alt: Video showing the frames selected by Compress++
 
-In this benchmark, we evaluated four coreset algorithms across various datasets and
-tasks, including image classification, synthetic datasets, and pixel/frame data
-processing. Based on the results, **kernel thinning** emerges as the preferred choice
-for most tasks due to its consistent performance. For larger datasets,
-combining kernel herding with distributed frameworks like **map reduce** is
-recommended to ensure scalability and efficiency.
+    **Video 7**: Frames selected by Compress++.
 
-For specialised tasks, such as frame selection from GIFs (Test 4), **Stein thinning**
-demonstrated superior performance and may be the optimal choice.
+    .. image:: ../../examples/benchmarking_images/ProbabilisticIterativeHerding_coreset.gif
+    :alt: Video showing the frames selected by Probabilistic Iterative Kernel Herding
+
+    **Video 8**: Frames selected by Probabilistic Iterative Kernel Herding.
+
+The following plots show the frames chosen by each coreset algorithm, with the action
+frames (those in which the pounce takes place) highlighted in orange.
+
+    .. image:: ../../examples/benchmarking_images/frames_random_sample.png
+       :alt: Plot showing the frames selected by Random Sample
+
+    .. image:: ../../examples/benchmarking_images/frames_rp_cholesky.png
+       :alt: Plot showing the frames selected by RP Cholesky
+
+    .. image:: ../../examples/benchmarking_images/frames_stein_thinning.png
+       :alt: Plot showing the frames selected by Stein Thinning
+
+    .. image:: ../../examples/benchmarking_images/frames_kernel_herding.png
+       :alt: Plot showing the frames selected by Kernel Herding
+
+    .. image:: ../../examples/benchmarking_images/frames_kernel_thinning.png
+       :alt: Plot showing the frames selected by Kernel Thinning
+
+    .. image:: ../../examples/benchmarking_images/frames_compress_plus_plus.png
+       :alt: Plot showing the frames selected by Compress++
+
+    .. image:: ../../examples/benchmarking_images/frames_probabilistic_iterative_herding.png
+       :alt: Plot showing the frames selected by Probabilistic Iterative Kernel Herding
+
+Conclusion
+----------
+This benchmark evaluated eight coreset algorithms across various tasks, including image
+classification and frame selection. *Iterative Kernel Herding* and *Kernel Thinning*
+emerged as the top performers, offering strong and consistent results. For large-scale
+datasets, *Compress++* and *MapReduce* provide efficient scalability. 
Ultimately, this conclusion reflects one interpretation of the results, and readers are encouraged to analyse the benchmarks and derive their own insights based on the specific diff --git a/documentation/source/conf.py b/documentation/source/conf.py index d375ada1..488febd4 100644 --- a/documentation/source/conf.py +++ b/documentation/source/conf.py @@ -209,6 +209,7 @@ ("py:obj", "coreax.metrics._Data"), ("py:obj", "coreax.solvers.coresubset._SupervisedData"), ("py:obj", "coreax.util.T"), + ("py:class", "pathlib._local.Path"), ] nitpick_ignore_regex = [ diff --git a/examples/benchmarking_images/CompressPlusPlus_coreset.gif b/examples/benchmarking_images/CompressPlusPlus_coreset.gif new file mode 100644 index 00000000..98539b9b --- /dev/null +++ b/examples/benchmarking_images/CompressPlusPlus_coreset.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ce663e91986bca48780766ff7c7f97a5999a9db0fa29bfb3eeb8320ebf5789f +size 333975 diff --git a/examples/benchmarking_images/KernelHerding_coreset.gif b/examples/benchmarking_images/KernelHerding_coreset.gif index 963226b8..6429b12b 100644 --- a/examples/benchmarking_images/KernelHerding_coreset.gif +++ b/examples/benchmarking_images/KernelHerding_coreset.gif @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0121f030f9a793e45135c56f5173ebdf5ca74f23c2285270d6dc0cb44aa96ab7 -size 331155 +oid sha256:52648a2739f0f8bff97f2415f70f6055e0f10383d7f1160c01ef99f15a54585f +size 339508 diff --git a/examples/benchmarking_images/KernelThinning_coreset.gif b/examples/benchmarking_images/KernelThinning_coreset.gif index eca296c5..8f9fff75 100644 --- a/examples/benchmarking_images/KernelThinning_coreset.gif +++ b/examples/benchmarking_images/KernelThinning_coreset.gif @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:093a0f1c5309307cd35352d388f153fe0e27718da4eaf11d394c4dc754eae10a -size 336313 +oid sha256:de19aa298471ee0dedf74c7f4a6c7191b24bf3f83444b405103c67c9ee13c1d9 +size 336133 diff --git a/examples/benchmarking_images/ProbabilisticIterativeHerding_coreset.gif b/examples/benchmarking_images/ProbabilisticIterativeHerding_coreset.gif new file mode 100644 index 00000000..80fa749d --- /dev/null +++ b/examples/benchmarking_images/ProbabilisticIterativeHerding_coreset.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49bff4bca774f963f41ff12e109b099c978ca41f8277031ffc5cefa934f2a014 +size 332878 diff --git a/examples/benchmarking_images/RPCholesky_coreset.gif b/examples/benchmarking_images/RPCholesky_coreset.gif index d8333f4b..5dd253d0 100644 --- a/examples/benchmarking_images/RPCholesky_coreset.gif +++ b/examples/benchmarking_images/RPCholesky_coreset.gif @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6f5c9e864cb4f92e6fcd65b8f76f67bdda602c82dca077eff28d38c6356aabbd -size 319587 +oid sha256:c1cb27571044b8e63f22d0b412b2eb638371b027616d316b5d1bbc00010620cf +size 323131 diff --git a/examples/benchmarking_images/SteinThinning_coreset.gif b/examples/benchmarking_images/SteinThinning_coreset.gif index 498aa6a9..b3ed5017 100644 --- a/examples/benchmarking_images/SteinThinning_coreset.gif +++ b/examples/benchmarking_images/SteinThinning_coreset.gif @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8f7edf1b11abf5ee8e0ba8fa09eea7221cfc3881905898e4fa64236bfb43e771 -size 314507 +oid sha256:d55d0f50b2e494aefe2b107253235e4b244920bc526e813478f06f5c255fa571 +size 315021 diff --git a/examples/benchmarking_images/blobs_benchmark_results.png 
b/examples/benchmarking_images/blobs_benchmark_results.png index 8878c186..0ba8e4cb 100644 --- a/examples/benchmarking_images/blobs_benchmark_results.png +++ b/examples/benchmarking_images/blobs_benchmark_results.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d144b777b2cddc86191bfdef932c8bf55f9e08f22e165650c6941a4733174b59 -size 152391 +oid sha256:801fefbef7ba384f5483568b16f14d43becba1f0bf95540946e032d0c49ba7d1 +size 198141 diff --git a/examples/benchmarking_images/blobs_time_taken.png b/examples/benchmarking_images/blobs_time_taken.png new file mode 100644 index 00000000..5ef1a51f --- /dev/null +++ b/examples/benchmarking_images/blobs_time_taken.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0103f52cc6d9e02de5f9b402e57257e0456cb5411d07825630af844cc6ee94ba +size 115777 diff --git a/examples/benchmarking_images/blobs_unweighted_ksd.png b/examples/benchmarking_images/blobs_unweighted_ksd.png new file mode 100644 index 00000000..db066d25 --- /dev/null +++ b/examples/benchmarking_images/blobs_unweighted_ksd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cd72e6b6b5c898f2edb73dea7db3b03e2e7db33442caba4574a29e5d585378e +size 114135 diff --git a/examples/benchmarking_images/blobs_unweighted_mmd.png b/examples/benchmarking_images/blobs_unweighted_mmd.png new file mode 100644 index 00000000..3c4b00e2 --- /dev/null +++ b/examples/benchmarking_images/blobs_unweighted_mmd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d4a97c1660d40af9546767709f46034975d5975abe210cb6b29ae932fa9cbd5 +size 154915 diff --git a/examples/benchmarking_images/blobs_weighted_ksd.png b/examples/benchmarking_images/blobs_weighted_ksd.png new file mode 100644 index 00000000..fc2373ee --- /dev/null +++ b/examples/benchmarking_images/blobs_weighted_ksd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a22a460aaa6b28ad667e182cf3e911e1a9ff3433156215e1c0a25eab3e210da8 +size 123498 diff --git a/examples/benchmarking_images/blobs_weighted_mmd.png b/examples/benchmarking_images/blobs_weighted_mmd.png new file mode 100644 index 00000000..17dc744e --- /dev/null +++ b/examples/benchmarking_images/blobs_weighted_mmd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3f3f41fa44868523a0b8ec339c313a50d0575926689ee6117dc34662ae94d35 +size 170013 diff --git a/examples/benchmarking_images/david_benchmark_results.png b/examples/benchmarking_images/david_benchmark_results.png index 44b265ec..41e53f65 100644 --- a/examples/benchmarking_images/david_benchmark_results.png +++ b/examples/benchmarking_images/david_benchmark_results.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:44b968f02e3ab99502592227f452baf8c870dd73039f4b45df68ae9e9c475417 -size 757481 +oid sha256:7c2f563f928102001022fc4e95177360a0d5114e5f27d10e9af5876e81ae8b99 +size 463068 diff --git a/examples/benchmarking_images/frames_compress_plus_plus.png b/examples/benchmarking_images/frames_compress_plus_plus.png new file mode 100644 index 00000000..a039bd1b --- /dev/null +++ b/examples/benchmarking_images/frames_compress_plus_plus.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72f9b4b33d5e064317e0f17243f66503e08d36559f8b1a82d215c363c5308566 +size 24775 diff --git a/examples/benchmarking_images/frames_kernel_herding.png b/examples/benchmarking_images/frames_kernel_herding.png new file mode 100644 index 00000000..2e7bdd9b --- /dev/null +++ 
b/examples/benchmarking_images/frames_kernel_herding.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b0fcb5d52e4c7a4a9ec29a4a64c93b043fe2207efd9950764ba8e554c20ff39 +size 23793 diff --git a/examples/benchmarking_images/frames_kernel_thinning.png b/examples/benchmarking_images/frames_kernel_thinning.png new file mode 100644 index 00000000..dffc4292 --- /dev/null +++ b/examples/benchmarking_images/frames_kernel_thinning.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab8206f46b2f8f22ca04400046a0847049b149f863d617939293ecaae140ec71 +size 24186 diff --git a/examples/benchmarking_images/frames_probabilistic_iterative_herding.png b/examples/benchmarking_images/frames_probabilistic_iterative_herding.png new file mode 100644 index 00000000..f17d7b58 --- /dev/null +++ b/examples/benchmarking_images/frames_probabilistic_iterative_herding.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:912e401aeb4a3f697980482f90528d08469660c49fbbdc865b5cb0828deb4f0a +size 26897 diff --git a/examples/benchmarking_images/frames_random_sample.png b/examples/benchmarking_images/frames_random_sample.png new file mode 100644 index 00000000..00a939a9 --- /dev/null +++ b/examples/benchmarking_images/frames_random_sample.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7846d206d6a3cce09cc07da1e7336d0872212e13e16d2a2b29a83371ba1a5827 +size 25004 diff --git a/examples/benchmarking_images/frames_rp_cholesky.png b/examples/benchmarking_images/frames_rp_cholesky.png new file mode 100644 index 00000000..bdc9d940 --- /dev/null +++ b/examples/benchmarking_images/frames_rp_cholesky.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d46bc9a9968f3f1e904832e50b3b14bfc8211b77235c23aadf259e35b1ef602d +size 24668 diff --git a/examples/benchmarking_images/frames_stein_thinning.png b/examples/benchmarking_images/frames_stein_thinning.png new file mode 100644 index 00000000..73171e6a --- /dev/null +++ b/examples/benchmarking_images/frames_stein_thinning.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec6d9cce86b54af586ebb45c9d36c82fd2b6a7888f737dcbacd50aa58e4d649c +size 23755 diff --git a/examples/benchmarking_images/mnist_benchmark_accuracy.png b/examples/benchmarking_images/mnist_benchmark_accuracy.png index ca35c12c..5b53e912 100644 --- a/examples/benchmarking_images/mnist_benchmark_accuracy.png +++ b/examples/benchmarking_images/mnist_benchmark_accuracy.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2c705227205915da97a4c80cabee1048caba54995928c0b623271724a00d41c7 -size 65234 +oid sha256:bf0ff143d931f1e81a9ada964f875ed0fb7917e7c73959e820d8033a0f2cc1a8 +size 139614 diff --git a/examples/benchmarking_images/mnist_benchmark_time_taken.png b/examples/benchmarking_images/mnist_benchmark_time_taken.png index 1df5c846..2bba41c9 100644 --- a/examples/benchmarking_images/mnist_benchmark_time_taken.png +++ b/examples/benchmarking_images/mnist_benchmark_time_taken.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5dfe92ce3124097afb476f2827620c05109f4c8352a0a8e96f73c79522bbc4bd -size 52228 +oid sha256:fe77ac366df955f0dd4da5532dcd9be67050c81d1fac1202d75719145004ce1a +size 109829 diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index b273e8b2..414d07f1 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -33,21 +33,23 @@ train_and_evaluate, ) from coreax import Data -from 
coreax.benchmark_util import calculate_delta, get_solver_name, initialise_solvers -from coreax.kernels.scalar_valued import SquaredExponentialKernel +from coreax.benchmark_util import calculate_delta, initialise_solvers from coreax.solvers import ( + CompressPlusPlus, + IterativeKernelHerding, KernelHerding, + KernelThinning, MapReduce, RandomSample, RPCholesky, + SteinThinning, ) class MockDataset(VisionDataset): """Mock dataset class for testing purposes.""" - # We deliberately don't call super().__init__(), as this is a mock class - def __init__(self, data: torch.Tensor, labels: torch.Tensor) -> None: # pylint: disable=super-init-not-called + def __init__(self, data: torch.Tensor, labels: torch.Tensor) -> None: """ Initialise the MockDataset. @@ -56,6 +58,7 @@ def __init__(self, data: torch.Tensor, labels: torch.Tensor) -> None: # pylint: :param data: A tensor containing the dataset features. :param labels: A tensor containing the corresponding labels. """ + super().__init__(root="", transform=None, target_transform=None) self.data = data self.labels = labels @@ -135,41 +138,79 @@ def test_initialise_solvers() -> None: """ Test the :func:`initialise_solvers`. - Verify that the returned list contains callable functions that produce + Verify that the returned dictionary contains callable functions that produce valid solver instances. """ # Create a mock dataset (UMAP-transformed) with arbitrary values mock_data = Data(jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])) key = random.PRNGKey(42) - - solvers = initialise_solvers(mock_data, key) - for solver in solvers: - solver_instance = solver(1) # Instantiate with a coreset size of 1 - assert isinstance(solver_instance, (MapReduce, RandomSample, RPCholesky)), ( - f"Unexpected solver type: {type(solver_instance)}" - ) - - -def test_get_solver_name(): + cpp_oversampling_factor = 1 + + # Initialise solvers + solvers = initialise_solvers(mock_data, key, cpp_oversampling_factor) + + # Ensure solvers is a dictionary with the expected keys + expected_solver_keys = [ + "Random Sample", + "RP Cholesky", + "Kernel Herding", + "Stein Thinning", + "Kernel Thinning", + "Compress++", + "Probabilistic Iterative Herding", + "Iterative Herding", + ] + assert set(solvers.keys()) == set(expected_solver_keys) + + +def test_solver_instances() -> None: """ - Test `get_solver_name` function to ensure it returns correct solver names. + Test :func:`initialise_solvers` returns an instance of the expected solver type. 
""" - # Create a KernelHerding solver - herding_solver = KernelHerding(coreset_size=5, kernel=SquaredExponentialKernel()) + mock_data = Data(jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])) + key = random.PRNGKey(42) + cpp_oversampling_factor = 1 + # Case 1: When leaf_size is not provided + solvers_no_leaf = initialise_solvers(mock_data, key, cpp_oversampling_factor) + + expected_solver_types_no_leaf = { + "Random Sample": RandomSample, + "RP Cholesky": RPCholesky, + "Kernel Herding": KernelHerding, + "Stein Thinning": SteinThinning, + "Kernel Thinning": KernelThinning, + "Compress++": CompressPlusPlus, + "Probabilistic Iterative Herding": IterativeKernelHerding, + "Iterative Herding": IterativeKernelHerding, + } - # Wrap it in MapReduce - map_reduce_solver = MapReduce(base_solver=herding_solver, leaf_size=15) + for solver_name, solver_function in solvers_no_leaf.items(): + solver_instance = solver_function(1) + assert isinstance(solver_instance, expected_solver_types_no_leaf[solver_name]) - assert get_solver_name(lambda _: herding_solver) == "KernelHerding", ( - "Expected 'KernelHerding' but got something else." + # Case 2: When leaf_size is provided + leaf_size = 2 + solvers_with_leaf = initialise_solvers( + mock_data, key, cpp_oversampling_factor, leaf_size ) - assert get_solver_name(lambda _: map_reduce_solver) == "KernelHerding", ( - "Expected 'KernelHerding' from MapReduce solver but got something else." - ) + expected_solver_types_with_leaf = { + "Random Sample": RandomSample, + "RP Cholesky": RPCholesky, + "Kernel Herding": MapReduce, + "Stein Thinning": MapReduce, + "Kernel Thinning": MapReduce, + "Compress++": CompressPlusPlus, + "Probabilistic Iterative Herding": MapReduce, + "Iterative Herding": MapReduce, + } + + for solver_name, solver_function in solvers_with_leaf.items(): + solver_instance = solver_function(1) + assert isinstance(solver_instance, expected_solver_types_with_leaf[solver_name]) -@pytest.mark.parametrize("n", [10, 100, 1000]) +@pytest.mark.parametrize("n", [1, 2, 100]) def test_calculate_delta(n): """ Test the `calculate_delta` function. @@ -177,7 +218,7 @@ def test_calculate_delta(n): Ensure that the function produces a positive delta value for different values of n. """ delta = calculate_delta(n) - assert delta > 0, f"Delta should be positive but got {delta} for n={n}" + assert delta > 0 if __name__ == "__main__": diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index fd1408ad..b855c9c0 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -59,6 +59,7 @@ GreedyKernelPoints, GreedyKernelPointsState, HerdingState, + IterativeKernelHerding, KernelHerding, KernelThinning, MapReduce, @@ -114,7 +115,7 @@ class SolverTest: shape: tuple[int, int] = (128, 10) @abstractmethod - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: """ Pytest fixture that returns a partially applied solver initialiser. 
@@ -540,7 +541,8 @@ class TestKernelHerding(RefinementSolverTest, ExplicitSizeSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request kernel = PCIMQKernel() coreset_size = self.shape[0] // 10 return jtu.Partial(KernelHerding, coreset_size=coreset_size, kernel=kernel) @@ -1128,7 +1130,8 @@ def check_solution_invariants( @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request coreset_size = self.shape[0] // 10 key = jr.fold_in(self.random_key, self.shape[0]) return jtu.Partial(RandomSample, coreset_size=coreset_size, random_key=key) @@ -1151,7 +1154,8 @@ def check_solution_invariants( @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request) -> jtu.Partial: + del request kernel = PCIMQKernel() coreset_size = self.shape[0] // 10 return jtu.Partial( @@ -1436,7 +1440,8 @@ class TestSteinThinning(RefinementSolverTest, ExplicitSizeSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request kernel = PCIMQKernel() coreset_size = self.shape[0] // 10 return jtu.Partial(SteinThinning, coreset_size=coreset_size, kernel=kernel) @@ -1802,7 +1807,8 @@ class TestGreedyKernelPoints(RefinementSolverTest, ExplicitSizeSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request) -> jtu.Partial: + del request feature_kernel = SquaredExponentialKernel() coreset_size = self.shape[0] // 10 return jtu.Partial( @@ -1957,7 +1963,9 @@ class TestMapReduce(SolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request) -> jtu.Partial: + del request + class _MockTree: def __init__(self, _data: np.ndarray, **kwargs): del kwargs @@ -2263,7 +2271,8 @@ class TestCaratheodoryRecombination(RecombinationSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request return jtu.Partial(CaratheodoryRecombination, test_functions=None, rcond=None) @@ -2272,7 +2281,8 @@ class TestTreeRecombination(RecombinationSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request return jtu.Partial( TreeRecombination, test_functions=None, rcond=None, tree_reduction_factor=3 ) @@ -2283,7 +2293,8 @@ class TestKernelThinning(ExplicitSizeSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request kernel = PCIMQKernel() coreset_size = self.shape[0] // 10 return jtu.Partial( @@ -2500,7 +2511,8 @@ class TestCompressPlusPlus(ExplicitSizeSolverTest): @override @pytest.fixture(scope="class") - def solver_factory(self) -> jtu.Partial: + def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial: + del request kernel = SquaredExponentialKernel() coreset_size = self.shape[0] // 8 return jtu.Partial( @@ -2551,3 +2563,21 @@ def test_invalid_coreset_size_incompatible(self): sqrt_kernel=SquaredExponentialKernel(), ) 
solver.reduce(dataset)
+
+
+class TestIterativeKernelHerding(ExplicitSizeSolverTest):
+    """Test cases for :class:`coreax.solvers.coresubset.IterativeKernelHerding`."""
+
+    @override
+    @pytest.fixture(scope="class", params=[True, False])
+    def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
+        kernel = PCIMQKernel()
+        coreset_size = self.shape[0] // 10
+        return jtu.Partial(
+            IterativeKernelHerding,
+            coreset_size=coreset_size,
+            random_key=self.random_key,
+            kernel=kernel,
+            probabilistic=request.param,
+            num_iterations=2,
+        )
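
For reference, a minimal usage sketch of the `IterativeKernelHerding` solver and the
reworked `initialise_solvers` factory introduced in this diff. The toy dataset, coreset
size, kernel length scale, and parameter values below are illustrative placeholders and
are not taken from the benchmark configuration:

    import jax.numpy as jnp
    from jax import random

    from coreax import Data
    from coreax.benchmark_util import initialise_solvers
    from coreax.kernels import SquaredExponentialKernel
    from coreax.solvers import IterativeKernelHerding

    # Illustrative toy dataset: 128 two-dimensional points wrapped in coreax's Data.
    dataset = Data(jnp.linspace(0.0, 1.0, 256).reshape(128, 2))

    # Standalone use: one Kernel Herding reduction, then `num_iterations` refine passes.
    solver = IterativeKernelHerding(
        coreset_size=10,
        kernel=SquaredExponentialKernel(length_scale=0.5),
        probabilistic=True,  # choose elements probabilistically at each iteration
        temperature=0.001,  # controls how uniform the selection probabilities are
        random_key=random.PRNGKey(0),
        num_iterations=5,
    )
    coreset, _ = solver.reduce(dataset)
    print(coreset.unweighted_indices)

    # Via the benchmark utilities: factories are now keyed by solver name; passing
    # `leaf_size` would instead wrap the applicable solvers in MapReduce.
    factories = initialise_solvers(dataset, random.PRNGKey(45), cpp_oversampling_factor=2)
    coreset, _ = factories["Iterative Herding"](10).reduce(dataset)

Note that `cpp_oversampling_factor` is only consumed by the Compress++ factory (as its
`g` parameter); the other factories ignore it.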