Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sml/kmeans #277

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions sml/kmeans/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")

package(default_visibility = ["//visibility:public"])

py_library(
name = "kmeans",
srcs = ["kmeans.py"],
deps = [
"//sml/utils:fxp_approx",
],
)

py_binary(
name = "kmeans_emul",
srcs = ["kmeans_emul.py"],
deps = [
":kmeans",
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
"//sml/utils:emulation"
],
)



py_test(
name = "kmeans_test",
srcs = ["kmeans_test.py"],
data = [
"//examples/python/conf", # FIXME: remove examples dependency
],
deps = [
":kmeans",
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
"//spu:init",
"//spu/utils:simulation",
],
)
86 changes: 86 additions & 0 deletions sml/kmeans/kmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import jax
import jax.numpy as jnp

class KMEANS:
"""
Parameters
----------
n_clusters : int
The number of clusters to form as well as the number of
centroids to generate.

n_samples : int
The number of samples.

max_iter : int, default=300
Maximum number of iterations of the k-means algorithm for a
single run.

tol : float, default=1e-4
Acceptable error to consider the two to be equal.
"""
def __init__(self, n_clusters, n_samples, max_iter=300, tol=1e-4):
self.n_clusters = n_clusters
self.max_iter = max_iter
self.tol = tol
self.init_params = jax.random.randint(jax.random.PRNGKey(1),shape=[self.n_clusters],minval=0,maxval=n_samples)
self._centers = jnp.zeros(())

def fit(self, x):
"""Fit KMEANS.

Firstly, randomly select the initial centers. Then calculate the distance between each sample and each center,
and assign each sample to the nearest center. Use an `aligned_array` to indicate the samples in a cluster,
where unrelated samples will be set to 0. Once all samples are assigned, the center of each cluster will
be updated to the average. The average could be got by `sum(data * aligned_array) / sum(aligned_array)`.
Different clusters could use broadcast for better performance.

Parameters
----------
x : {array-like}, shape (n_samples, n_features)
Input data.

Returns
-------
self : object
Returns an instance of self.
"""

centers = jnp.array([x[i] for i in self.init_params])
for _ in range(self.max_iter):
C = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape((centers.shape[0], 1, centers.shape[1]))
C = jnp.argmin(jnp.sum(jnp.square(C), axis=2), axis=0)

S = jnp.tile(C,(self.n_clusters,1))
ks = jnp.arange(self.n_clusters)
aligned_array_raw = (S.T - ks).T
aligned_array = jnp.equal(aligned_array_raw, 0)

centers_raw = x.reshape((1, x.shape[0], x.shape[1])) * aligned_array.reshape((aligned_array.shape[0],aligned_array.shape[1],1))
equals_sum = jnp.sum(aligned_array,axis=1)
centers_sum = jnp.sum(centers_raw, axis=1)
centers = jnp.divide(centers_sum.T, equals_sum).T

self._centers = centers
return self

def predict(self, x):
"""Result estimates.

Calculate the distance between each sample and each center,
and assign each sample to the nearest center.

Parameters
----------
x : {array-like}, shape (n_samples, n_features)
Input data for prediction.

Returns
-------
ndarray of shape (n_samples)
Returns the result of the sample for each class in the model.
"""
centers = self._centers
y = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape((centers.shape[0], 1, centers.shape[1]))
y = jnp.argmin(jnp.sum(jnp.square(y), axis=2), axis=0)
return y
54 changes: 54 additions & 0 deletions sml/kmeans/kmeans_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jax.numpy as jnp

# from sklearn.metrics import roc_auc_score, explained_variance_score
import sml.utils.emulation as emulation

from sml.kmeans.kmeans import KMEANS
from sklearn.datasets import make_blobs

# TODO: design the enumation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)
def emul_KMEANS(mode: emulation.Mode.MULTIPROCESS):
def proc(x1, x2):
x = jnp.concatenate((x1, x2), axis=1)
model = KMEANS(
n_clusters=2,
n_samples=x.shape[0],
max_iter=10
)

return model.fit(x).predict(x)

def load_data():
n_samples = 1000
n_features = 100
X, _ = make_blobs(n_samples=n_samples,n_features=n_features,centers=2)
split_index = n_features//2
return X[:, :split_index], X[:, split_index:]

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
"examples/python/conf/3pc.json", mode, bandwidth=300, latency=20
)
emulator.up()

# load mock data
x1, x2 = load_data()
X = jnp.concatenate((x1, x2), axis=1)

# mark these data to be protected in SPU
x1, x2 = emulator.seal(x1, x2)
result = emulator.run(proc)(x1, x2)
print("result\n",result)

# Compare with sklearn
from sklearn.cluster import KMeans
model = KMeans(n_clusters=2)
print("sklearn:\n",model.fit(X).predict(X))
finally:
emulator.down()


if __name__ == "__main__":
emul_KMEANS(emulation.Mode.MULTIPROCESS)
50 changes: 50 additions & 0 deletions sml/kmeans/kmeans_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
import json
import jax.numpy as jnp

# from sklearn.metrics import roc_auc_score, explained_variance_score
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2 # type: ignore

# TODO: unify this.
import examples.python.utils.dataset_utils as dsutil

from sml.kmeans.kmeans import KMEANS
from sklearn.datasets import make_blobs

class UnitTests(unittest.TestCase):
def test_kmeans(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

def proc(x1, x2):
x = jnp.concatenate((x1, x2), axis=1)
model = KMEANS(
n_clusters=2,
n_samples=x.shape[0],
max_iter=10
)

return model.fit(x).predict(x)

def load_data():
n_samples = 1000
n_features = 100
X, _ = make_blobs(n_samples=n_samples,n_features=n_features,centers=2)
split_index = n_features//2
return X[:, :split_index], X[:, split_index:]

x1, x2 = load_data()
X = jnp.concatenate((x1, x2), axis=1)
result = spsim.sim_jax(sim, proc)(x1, x2)
print("result\n",result)

# Compare with sklearn
from sklearn.cluster import KMeans
model = KMeans(n_clusters=2)
print("sklearn:\n",model.fit(X).predict(X))


if __name__ == "__main__":
unittest.main()