Skip to content

Commit d5fa077

Browse files
committed
Add region_aware stuff
1 parent 627e6c6 commit d5fa077

File tree

3 files changed

+77
-30
lines changed

3 files changed

+77
-30
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,7 @@ junit.xml
7474

7575
# Pytest cache
7676
.pytest_cache/
77-
Dockerfile
77+
Dockerfile
78+
79+
# Extra Documentation Stuff
80+
release_notes.md

gerrychain/proposals/tree_proposals.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import partial
2+
from inspect import signature
23
from ..random import random
34

45
from ..tree import (
@@ -9,7 +10,9 @@
910

1011

1112
def recom(
12-
partition, pop_col, pop_target, epsilon, node_repeats=1, method=bipartition_tree
13+
partition, pop_col, pop_target, epsilon, node_repeats=1,
14+
weight_dict = None,
15+
method=bipartition_tree
1316
):
1417
"""ReCom proposal.
1518
@@ -45,6 +48,11 @@ def recom(
4548
partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
4649
)
4750

51+
# Try to add the region aware in if the method accepts the weight dictionary
52+
if 'weight_dict' in signature(method).parameters:
53+
method = partial(method, weight_dict=weight_dict)
54+
55+
4856
flips = recursive_tree_part(
4957
subgraph.graph,
5058
parts_to_merge,

gerrychain/tree.py

+64-28
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from networkx.algorithms import tree
33

44
from functools import partial
5+
from inspect import signature
56
from .random import random
67
from collections import deque, namedtuple
78
from typing import Any, Callable, Dict, List, Optional, Set, Union, Sequence
@@ -15,19 +16,28 @@ def successors(h: nx.Graph, root: Any) -> Dict:
1516
return {a: b for a, b in nx.bfs_successors(h, root)}
1617

1718

18-
def random_spanning_tree(graph: nx.Graph) -> nx.Graph:
19-
""" Builds a spanning tree chosen by Kruskal's method using random weights.
20-
:param graph: FrozenGraph
21-
22-
Important Note:
23-
The key is specifically labelled "random_weight" instead of the previously
24-
used "weight". Turns out that networkx uses the "weight" keyword for other
25-
operations, like when computing the laplacian or the adjacency matrix.
26-
This meant that the laplacian would change for the graph step to step,
27-
something that we do not intend!!
19+
def random_spanning_tree(graph: nx.Graph, weight_dict: Dict) -> nx.Graph:
20+
"""
21+
Builds a spanning tree chosen by Kruskal's method using random weights.
22+
23+
:param graph: The input graph to build the spanning tree from. Should be a Networkx Graph.
24+
:type graph: nx.Graph
25+
:param weight_dict: Dictionary of weights to add to the random weights used in region-aware variants.
26+
:type weight_dict: Dict
27+
:return: The maximal spanning tree represented as a Networkx Graph.
28+
:rtype: nx.Graph
2829
"""
29-
for edge in graph.edge_indices:
30-
graph.edges[edge]["random_weight"] = random.random()
30+
if weight_dict is None:
31+
weight_dict = dict()
32+
33+
for edge in graph.edges():
34+
weight = random.random()
35+
for key, value in weight_dict.items():
36+
if graph.nodes[edge[0]][key] == graph.nodes[edge[1]][key] and \
37+
graph.nodes[edge[0]][key] is not None:
38+
weight += value
39+
40+
graph.edges[edge]["random_weight"] = weight
3141

3242
spanning_tree = tree.maximum_spanning_tree(
3343
graph, algorithm="kruskal", weight="random_weight"
@@ -179,35 +189,61 @@ def bipartition_tree(
179189
node_repeats: int = 1,
180190
spanning_tree: Optional[nx.Graph] = None,
181191
spanning_tree_fn: Callable = random_spanning_tree,
192+
weight_dict: Dict = None,
182193
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
183194
choice: Callable = random.choice,
195+
max_attempts: Optional[int] = 10000
184196
max_attempts: Optional[int] = None
185197
) -> Set:
186-
"""This function finds a balanced 2 partition of a graph by drawing a
198+
"""
199+
This function finds a balanced 2 partition of a graph by drawing a
187200
spanning tree and finding an edge to cut that leaves at most an epsilon
188201
imbalance between the populations of the parts. If a root fails, new roots
189202
are tried until node_repeats in which case a new tree is drawn.
190203
191204
Builds up a connected subgraph with a connected complement whose population
192205
is ``epsilon * pop_target`` away from ``pop_target``.
193206
194-
Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
195-
The other part of the partition is the complement of this subset.
196207
197-
:param graph: The graph to partition
198-
:param pop_col: The node attribute holding the population of each node
199-
:param pop_target: The target population for the returned subset of nodes
200-
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
201-
``pop_target``) for the subgraph's population
202-
:param node_repeats: A parameter for the algorithm: how many different choices
203-
of root to use before drawing a new spanning tree.
204-
:param spanning_tree: The spanning tree for the algorithm to use (used when the
205-
algorithm chooses a new root and for testing)
206-
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
207-
tree is not provided
208-
:param choice: :func:`random.choice`. Can be substituted for testing.
209-
:param max_atempts: The max number of attempts that should be made to bipartition.
208+
:param graph: The graph to partition.
209+
:type graph: nx.Graph
210+
:param pop_col: The node attribute holding the population of each node.
211+
:type pop_col: str
212+
:param pop_target: The target population for the returned subset of nodes.
213+
:type pop_target: Union[int, float]
214+
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
215+
``pop_target``) for the subgraph's population.
216+
:type epsilon: float
217+
:param node_repeats: A parameter for the algorithm: how many different choices
218+
of root to use before drawing a new spanning tree. Defaults to 1.
219+
:type node_repeats: int
220+
:param spanning_tree: The spanning tree for the algorithm to use (used when the
221+
algorithm chooses a new root and for testing).
222+
:type spanning_tree: Optional[nx.Graph]
223+
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
224+
tree is not provided. Defaults to :func:`random_spanning_tree`.
225+
:type spanning_tree_fn: Callable
226+
:param weight_dict: A dictionary of weights for the spanning tree algorithm.
227+
Defaults to None.
228+
:type weight_dict: Dict, optional
229+
:param balance_edge_fn: The function to find balanced edge cuts. Defaults to
230+
:func:`find_balanced_edge_cuts_memoization`.
231+
:type balance_edge_fn: Callable, optional
232+
:param choice: The function to make a random choice. Can be substituted for testing.
233+
Defaults to :func:`random.choice`.
234+
:type choice: Callable
235+
:param max_attempts: The maximum number of attempts that should be made to bipartition.
236+
Defaults to 1000.
237+
:type max_attempts: Optional[int]
238+
:return: A subset of nodes of ``graph`` (whose induced subgraph is connected). The other
239+
part of the partition is the complement of this subset.
240+
:rtype: Set
241+
:raises RuntimeError: If a possible cut cannot be found after the maximum number of attempts.
210242
"""
243+
# Try to add the region-aware in if the spanning_tree_fn accepts a weight dictionary
244+
if 'weight_dict' in signature(spanning_tree_fn).parameters:
245+
spanning_tree_fn = partial(spanning_tree_fn, weight_dict=weight_dict)
246+
211247
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}
212248

213249
possible_cuts = []

0 commit comments

Comments
 (0)