Skip to content

Commit

Permalink
add class
Browse files Browse the repository at this point in the history
  • Loading branch information
cpz2024 committed Mar 10, 2025
1 parent c253afb commit 906ea22
Show file tree
Hide file tree
Showing 21 changed files with 879 additions and 401 deletions.
4 changes: 2 additions & 2 deletions MODULE.bazel.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 5 additions & 7 deletions sml/manifold/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


load("//bazel:spu.bzl", "sml_py_library")

package(default_visibility = ["//visibility:public"])


sml_py_library(
name = "jacobi",
srcs = ["jacobi.py"],
Expand All @@ -29,8 +27,8 @@ sml_py_library(
)

sml_py_library(
name = "MDS",
srcs = ["MDS.py"],
name = "isomap",
srcs = ["isomap.py"],
)

sml_py_library(
Expand All @@ -44,6 +42,6 @@ sml_py_library(
)

sml_py_library(
name = "SE",
srcs = ["SE.py"],
)
name = "se",
srcs = ["se.py"],
)
58 changes: 0 additions & 58 deletions sml/manifold/MDS.py

This file was deleted.

53 changes: 0 additions & 53 deletions sml/manifold/SE.py

This file was deleted.

99 changes: 63 additions & 36 deletions sml/manifold/dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,93 @@


def set_value(x, index, value, n):
# Change the value at the index of array x to value, where index is secretly shared
# n: the length of array x

perm = jnp.zeros(n, dtype=jnp.int16)
perm_2 = jnp.zeros(n, dtype=jnp.int16)
for i in range(n):
perm = perm.at[i].set(i)
perm_2 = perm_2.at[i].set(index)
flag = jnp.equal(perm, perm_2)
set_x = jnp.select([flag], [value], x)
"""Change the value at the specified index of array x to a given value, where the index is secretly shared.
Args:
x: The input array to be modified.
index: The index at which the value should be set (secretly shared).
value: The new value to set at the specified index.
n: The length of the array x.
Returns:
The modified array with the value updated at the specified index.
"""
perm = jnp.arange(n)
perm_2 = jnp.ones(n) * index

set_x = jnp.where(perm == perm_2, value, x)

return set_x


def get_value_1(x, index, n):
# Obtain the value at the x[index] index, where index is a secret shared value
# n: the length of array x

perm = jnp.zeros(n, dtype=jnp.int16)
perm_2 = jnp.zeros(n, dtype=jnp.int16)
for i in range(n):
perm = perm.at[i].set(i)
perm_2 = perm_2.at[i].set(index)
"""Retrieve the value at the specified index of array x, where the index is secretly shared.
Args:
x: The input array from which to retrieve the value.
index: The index to retrieve the value from (secretly shared).
n: The length of the array x.
Returns:
The value at the specified index.
"""

perm = jnp.arange(n)
perm_2 = jnp.ones(n) * index
flag = jnp.equal(perm, perm_2)
return jnp.sum(flag * x)


def get_value_2(x, index_1, index_2, n):
# Obtain the value at index x[index_1][index_2], where index_2 is plaintext and index_1 is secret shared
# n: the length of array x
"""Retrieve the value at the specified 2D index of array x, where index_1 is secretly shared and index_2 is plaintext.
Args:
x: The input 2D array from which to retrieve the value.
index_1: The row index (secretly shared).
index_2: The column index (plaintext).
n: The size of the array x (assuming it is square).
Returns:
The value at the specified 2D index.
"""

# Initialize row index
perm_1 = jnp.zeros((n, n), dtype=jnp.int16)
perm_2_row = jnp.zeros((n, n), dtype=jnp.int16)
perm_1 = jnp.arange(n)[:, None]
perm_1 = jnp.tile(perm_1, (1, index_2 + 1))

for i in range(n):
for j in range(n):
perm_1 = perm_1.at[i, j].set(i)
perm_2_row = perm_2_row.at[i, j].set(index_1)
perm_2_row = jnp.ones((n, index_2 + 1)) * index_1

# Match rows
flag_row = jnp.equal(perm_1, perm_2_row)

# Extract column values directly using plaintext index_2
flag = flag_row[:, index_2]

# Return the value at the matching index
return jnp.sum(flag * x[:, index_2])


def mpc_dijkstra(adj_matrix, num_samples, start, dist_inf):
# adj_matrix: the adjacency matrix for calculating shortest path
# num_samples:The size of the adjacency matrix
# start:To calculate the shortest path for all point-to-point starts
# dis_inf:The initial shortest path for all point-to-point starts, set as inf
"""Use Dijkstra's algorithm to compute the shortest paths from the starting point to all other points.
Parameters
----------
adj_matrix : ndarray
The adjacency matrix used to compute the shortest paths.
num_samples : int
The size of the adjacency matrix (number of nodes).
start : int
The starting point for which to compute the shortest paths.
dist_inf : ndarray
The initialized shortest path array, usually set to infinity (inf).
Returns
-------
distances : ndarray
The shortest paths from the starting point to all other points.
"""

# Initialize with Inf value
sinf = dist_inf[0]
Expand All @@ -90,12 +122,7 @@ def mpc_dijkstra(adj_matrix, num_samples, start, dist_inf):
flag = (visited[v] == 0) * (distances[v] < min_distance)
min_distance = min_distance + flag * (distances[v] - min_distance)
min_index = min_index + flag * (v - min_index)
# min_distance = jax.lax.cond(flag, lambda _: distances[v], lambda _: min_distance)
# min_index = jax.lax.cond(flag, lambda _: v, lambda _: min_index)

# Mark as visited
# jax.lax.dynamic_update_slice(visited, 1, (min_index,))
# visited[min_index] = True
visited = set_value(visited, min_index, True, num_samples)

# Update the distance between adjacent nodes
Expand Down
29 changes: 4 additions & 25 deletions sml/manifold/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.


load("//bazel:spu.bzl", "sml_py_binary")

package(default_visibility = ["//visibility:public"])

sml_py_binary(
name = "Isomap_emul",
srcs = ["Isomap_emul.py"],
name = "isomap_emul",
srcs = ["isomap_emul.py"],
deps = [
"//sml/manifold:MDS",
"//sml/manifold:SE",
"//sml/manifold:dijkstra",
"//sml/manifold:floyd",
"//sml/manifold:isomap",
"//sml/manifold:jacobi",
"//sml/manifold:kneighbors",
"//spu/intrinsic:all_intrinsics",
Expand All @@ -35,27 +32,9 @@ sml_py_binary(
name = "se_emul",
srcs = ["se_emul.py"],
deps = [
"//sml/manifold:MDS",
"//sml/manifold:SE",
"//sml/manifold:dijkstra",
"//sml/manifold:floyd",
"//sml/manifold:jacobi",
"//sml/manifold:kneighbors",
"//spu/intrinsic:all_intrinsics",
],
)


sml_py_binary(
name = "knn_emul",
srcs = ["knn_emul.py"],
deps = [
"//sml/manifold:MDS",
"//sml/manifold:SE",
"//sml/manifold:dijkstra",
"//sml/manifold:floyd",
"//sml/manifold:jacobi",
"//sml/manifold:kneighbors",
"//sml/manifold:se",
"//spu/intrinsic:all_intrinsics",
],
)
Loading

0 comments on commit 906ea22

Please sign in to comment.