Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add check edge weights positive #28

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion dynnode2vec/dynnode2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
n_walks_per_node: int = 10,
embedding_size: int = 128,
window: int = 10,
weighted: bool = False,
seed: int | None = 0,
parallel_processes: int = 4,
plain_node2vec: bool = False,
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
assert (
isinstance(window, int) and embedding_size > 0
), "window should be a strictly positive integer"
assert isinstance(weighted, bool), "weighted should be a boolean"
assert (
seed is None or isinstance(seed, int)
) and embedding_size > 0, "seed should be either None or int"
Expand All @@ -84,13 +86,29 @@ def __init__(
self.n_walks_per_node = n_walks_per_node
self.embedding_size = embedding_size
self.window = window
self.weighted = weighted
self.seed = seed
self.parallel_processes = parallel_processes
self.plain_node2vec = plain_node2vec

# see https://stackoverflow.com/questions/53417258/what-is-workers-parameter-in-word2vec-in-nlp # pylint: disable=line-too-long
self.gensim_workers = max(self.parallel_processes - 1, 12)

def _check_edge_weights(self, graphs: list[nx.Graph]) -> None:
"""
Check that all edge weights are strictly positive, otherwise we can not run random walks.
"""
if not self.weighted:
return

for i, graph in enumerate(graphs):
weights = nx.get_edge_attributes(graph, name="weight")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check that weights are non-empty for all nodes in all graphs?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes good point !


assert all(weight > 0 for weight in weights.values()), (
"All edge weights should be strictly positive to run Dynnode2Vec "
f"found negative weight in graph {i}"
)

def _initialize_embeddings(
self, graphs: list[nx.Graph]
) -> tuple[Word2Vec, list[Embedding]]:
Expand Down Expand Up @@ -232,7 +250,7 @@ def compute_embeddings(self, graphs: list[nx.Graph]) -> list[Embedding]:
"""
Compute dynamic embeddings on a list of graphs.
"""
# TO DO : check graph weights valid
self._check_edge_weights(graphs)
model, embeddings = self._initialize_embeddings(graphs)
time_walks = self._simulate_walks(graphs)
self._update_embeddings(embeddings, time_walks, model)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_dynnode2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test the DynNode2Vec class
"""
# pylint: disable=missing-function-docstring
import random

import gensim
import networkx as nx
Expand All @@ -24,6 +25,13 @@ def dynnode2vec_fixture():
)


@pytest.fixture(name="weighted_dynnode2vec_object")
def weighted_dynnode2vec_fixture():
return dynnode2vec.DynNode2Vec(
n_walks_per_node=5, walk_length=5, weighted=True, parallel_processes=1
)


@pytest.fixture(name="parallel_dynnode2vec_object")
def dynnode2vec_parallel_fixture():
return dynnode2vec.DynNode2Vec(
Expand Down Expand Up @@ -93,6 +101,22 @@ def test_compute_embeddings(graphs, dynnode2vec_object):
assert all(isinstance(emb, dynnode2vec.Embedding) for emb in embeddings)


def test_compute_weighted_embeddings(graphs, weighted_dynnode2vec_object):
embeddings = weighted_dynnode2vec_object.compute_embeddings(graphs)

assert isinstance(embeddings, list)
assert all(isinstance(emb, dynnode2vec.Embedding) for emb in embeddings)

# add random negative weights to the graph and check that it raises
rng = random.Random(0)
for graph in graphs:
for _, _, data in graph.edges(data=True):
data["weight"] = -rng.random()

with pytest.raises(AssertionError):
weighted_dynnode2vec_object.compute_embeddings(graphs)


def test_parallel_compute_embeddings(graphs, parallel_dynnode2vec_object):
embeddings = parallel_dynnode2vec_object.compute_embeddings(graphs)

Expand Down