Skip to content

Commit 1e23e57

Browse files
authored
add filter method to LinkGraph (#269)
* add examples to docstring * add filter method to LinkGraph * use python3.10 for code format check * add stub for networkx to github action * fix bug of accessing non-existent objects in link graph
1 parent 5d1ba6f commit 1e23e57

File tree

4 files changed

+115
-3
lines changed

4 files changed

+115
-3
lines changed

.github/workflows/format-typing-check.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ jobs:
3030
runs-on: ubuntu-latest
3131
steps:
3232
- uses: actions/checkout@v4
33+
- name: Set up Python 3.10
34+
uses: actions/setup-python@v3
35+
with:
36+
python-version: '3.10'
3337
- name: Install ruff and mypy
3438
run: |
35-
pip install ruff mypy typing_extensions types-Deprecated types-beautifulsoup4 types-jsonschema pandas-stubs
39+
pip install ruff mypy typing_extensions \
40+
types-Deprecated types-beautifulsoup4 types-jsonschema types-networkx pandas-stubs
3641
- name: Get all changed python files
3742
id: changed-python-files
3843
uses: tj-actions/changed-files@v44

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ dev = [
5757
# static typing
5858
"mypy",
5959
"typing_extensions",
60-
# stub packages
60+
# stub packages. Update the `format-typing-check.yml` too if you add more.
6161
"types-Deprecated",
6262
"types-beautifulsoup4",
6363
"types-jsonschema",

src/nplinker/scoring/link_graph.py

+79-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from collections.abc import Sequence
23
from functools import wraps
34
from typing import Union
45
from networkx import Graph
@@ -79,7 +80,7 @@ def __init__(self) -> None:
7980
>>> lg[gcf]
8081
{spectrum: {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})}}
8182
82-
Get all links:
83+
Get all links in the LinkGraph:
8384
>>> lg.links
8485
[(gcf, spectrum, {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})})]
8586
@@ -129,6 +130,10 @@ def links(
129130
130131
Returns:
131132
A list of tuples containing the links between objects.
133+
134+
Examples:
135+
>>> lg.links
136+
[(gcf, spectrum, {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})})]
132137
"""
133138
return list(self._g.edges(data=True))
134139

@@ -150,6 +155,9 @@ def add_link(
150155
data: keyword arguments. At least one scoring method and its data must be provided.
151156
The key must be the name of the scoring method defined in `ScoringMethod`, and the
152157
value is a `Score` object, e.g. `metcalf=Score("metcalf", 1.0, {"cutoff": 0.5})`.
158+
159+
Examples:
160+
>>> lg.add_link(gcf, spectrum, metcalf=Score("metcalf", 1.0, {"cutoff": 0.5}))
153161
"""
154162
# validate the data
155163
if not data:
@@ -174,6 +182,10 @@ def has_link(self, u: Entity, v: Entity) -> bool:
174182
175183
Returns:
176184
True if there is a link between the two objects, False otherwise
185+
186+
Examples:
187+
>>> lg.has_link(gcf, spectrum)
188+
True
177189
"""
178190
return self._g.has_edge(u, v)
179191

@@ -192,5 +204,71 @@ def get_link_data(
192204
Returns:
193205
A dictionary of scoring methods and their data for the link between the two objects, or
194206
None if there is no link between the two objects.
207+
208+
Examples:
209+
>>> lg.get_link_data(gcf, spectrum)
210+
{"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})}
195211
"""
196212
return self._g.get_edge_data(u, v) # type: ignore
213+
214+
def filter(self, u_nodes: Sequence[Entity], v_nodes: Sequence[Entity] = [], /) -> LinkGraph:
215+
"""Return a new LinkGraph object with the filtered links between the given objects.
216+
217+
The new LinkGraph object will only contain the links between `u_nodes` and `v_nodes`.
218+
219+
If `u_nodes` or `v_nodes` is empty, the new LinkGraph object will contain the links for
220+
the given objects in `v_nodes` or `u_nodes`, respectively. If both are empty, return an
221+
empty LinkGraph object.
222+
223+
Note that not all objects in `u_nodes` and `v_nodes` need to be present in the original
224+
LinkGraph.
225+
226+
Args:
227+
u_nodes: a sequence of objects used as the first object in the links
228+
v_nodes: a sequence of objects used as the second object in the links
229+
230+
Returns:
231+
A new LinkGraph object with the filtered links between the given objects.
232+
233+
Examples:
234+
Filter the links for `gcf1` and `gcf2`:
235+
>>> new_lg = lg.filter([gcf1, gcf2])
236+
Filter the links for `spectrum1` and `spectrum2`:
237+
>>> new_lg = lg.filter([spectrum1, spectrum2])
238+
Filter the links between two lists of objects:
239+
>>> new_lg = lg.filter([gcf1, gcf2], [spectrum1, spectrum2])
240+
"""
241+
lg = LinkGraph()
242+
243+
# exchange u_nodes and v_nodes if u_nodes is empty but v_nodes not
244+
if len(u_nodes) == 0 and len(v_nodes) != 0:
245+
u_nodes = v_nodes
246+
v_nodes = []
247+
248+
if len(v_nodes) == 0:
249+
for u in u_nodes:
250+
self._filter_one_node(u, lg)
251+
252+
for u in u_nodes:
253+
for v in v_nodes:
254+
self._filter_two_nodes(u, v, lg)
255+
256+
return lg
257+
258+
@validate_u
259+
def _filter_one_node(self, u: Entity, lg: LinkGraph) -> None:
260+
"""Filter the links for a given object and add them to the new LinkGraph object."""
261+
try:
262+
links = self[u]
263+
except KeyError:
264+
pass
265+
else:
266+
for node2, value in links.items():
267+
lg.add_link(u, node2, **value)
268+
269+
@validate_uv
270+
def _filter_two_nodes(self, u: Entity, v: Entity, lg: LinkGraph) -> None:
271+
"""Filter the links between two objects and add them to the new LinkGraph object."""
272+
link_data = self.get_link_data(u, v)
273+
if link_data is not None:
274+
lg.add_link(u, v, **link_data)

tests/unit/scoring/test_link_graph.py

+29
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,32 @@ def test_has_link(lg, gcfs, spectra):
8383
def test_get_link_data(lg, gcfs, spectra, score):
8484
assert lg.get_link_data(gcfs[0], spectra[0]) == {"metcalf": score}
8585
assert lg.get_link_data(gcfs[0], spectra[1]) is None
86+
87+
88+
def test_filter(gcfs, spectra, score):
89+
lg = LinkGraph()
90+
lg.add_link(gcfs[0], spectra[0], metcalf=score)
91+
lg.add_link(gcfs[1], spectra[1], metcalf=score)
92+
93+
u_nodes = [gcfs[0], gcfs[1], gcfs[2]]
94+
v_nodes = [spectra[0], spectra[1], spectra[2]]
95+
96+
# test filtering with GCFs
97+
lg_filtered = lg.filter(u_nodes)
98+
assert len(lg_filtered) == 4 # number of nodes
99+
100+
# test filtering with Spectra
101+
lg_filtered = lg.filter(v_nodes)
102+
assert len(lg_filtered) == 4
103+
104+
# test empty `u_nodes` argument
105+
lg_filtered = lg.filter([], v_nodes)
106+
assert len(lg_filtered) == 4
107+
108+
# test empty `u_nodes` and `v_nodes` arguments
109+
lg_filtered = lg.filter([], [])
110+
assert len(lg_filtered) == 0
111+
112+
# test filtering with GCFs and Spectra
113+
lg_filtered = lg.filter(u_nodes, v_nodes)
114+
assert len(lg_filtered) == 4

0 commit comments

Comments
 (0)