Skip to content

Commit aa02ee4

Browse files
committed
Add doc strings and type hints to tree.py. Also optimize prime factor fn
1 parent d5fa077 commit aa02ee4

File tree

1 file changed

+195
-59
lines changed

1 file changed

+195
-59
lines changed

gerrychain/tree.py

+195-59
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from inspect import signature
66
from .random import random
77
from collections import deque, namedtuple
8-
from typing import Any, Callable, Dict, List, Optional, Set, Union, Sequence
8+
from typing import Any, Callable, Dict, List, Optional, Set, Union, Hashable, Sequence, Tuple
99

1010

1111
def predecessors(h: nx.Graph, root: Any) -> Dict:
@@ -45,11 +45,17 @@ def random_spanning_tree(graph: nx.Graph, weight_dict: Dict) -> nx.Graph:
4545
return spanning_tree
4646

4747

48-
def uniform_spanning_tree(graph: nx.Graph, choice: Callable = random.choice) -> nx.Graph:
49-
""" Builds a spanning tree chosen uniformly from the space of all
50-
spanning trees of the graph.
51-
:param graph: Networkx Graph
52-
:param choice: :func:`random.choice`
48+
def uniform_spanning_tree(
49+
graph: nx.Graph,
50+
choice: Callable = random.choice
51+
) -> nx.Graph:
52+
"""
53+
Builds a spanning tree chosen uniformly from the space of all
54+
spanning trees of the graph. Uses Wilson's algorithm.
55+
:param graph: Networkx Graph
56+
:type graph: nx.Graph
57+
:param choice: The function used to pick a random element from a list. Defaults to :func:`random.choice`.
58+
:type choice: Callable, optional
5359
"""
5460
root = choice(list(graph.node_indices))
5561
tree_nodes = set([root])
@@ -75,6 +81,19 @@ def uniform_spanning_tree(graph: nx.Graph, choice: Callable = random.choice) ->
7581

7682

7783
class PopulatedGraph:
84+
"""
85+
A class representing a graph with population information.
86+
87+
:param graph: The underlying graph structure.
88+
:type graph: nx.Graph
89+
:param populations: A dictionary mapping nodes to their populations.
90+
:type populations: Dict
91+
:param ideal_pop: The ideal population for each district.
92+
:type ideal_pop: float
93+
:param epsilon: The tolerance for population deviation from the ideal population within each
94+
district.
95+
:type epsilon: float
96+
"""
7897
def __init__(
7998
self,
8099
graph: nx.Graph,
@@ -107,12 +126,27 @@ def has_ideal_population(self, node) -> bool:
107126
)
108127

109128

129+
130+
# Tuple that is used in the find_balanced_edge_cuts function
131+
# Comment added to make this easier to find
110132
Cut = namedtuple("Cut", "edge subset")
111133

112134

113135
def find_balanced_edge_cuts_contraction(
114-
h: PopulatedGraph, choice: Callable = random.choice) -> List[Cut]:
115-
# this used to be greater than 2 but failed on small grids:(
136+
h: PopulatedGraph,
137+
choice: Callable = random.choice
138+
) -> List[Cut]:
139+
"""
140+
Find balanced edge cuts using contraction.
141+
142+
:param h: The populated graph.
143+
:type h: PopulatedGraph
144+
:param choice: The function used to make random choices.
145+
:type choice: Callable, optional
146+
:return: A list of balanced edge cuts.
147+
:rtype: List[Cut]
148+
"""
149+
116150
root = choice([x for x in h if h.degree(x) > 1])
117151
# BFS predecessors for iteratively contracting leaves
118152
pred = predecessors(h.graph, root)
@@ -135,6 +169,21 @@ def find_balanced_edge_cuts_memoization(
135169
h: PopulatedGraph,
136170
choice: Callable = random.choice
137171
) -> List[Any]:
172+
"""
173+
Find balanced edge cuts using memoization.
174+
175+
This function takes a PopulatedGraph object and a choice function as input and returns a list of balanced edge cuts.
176+
A balanced edge cut is defined as a cut that divides the graph into two subsets, such that the population of each subset
177+
is close to the ideal population defined by the PopulatedGraph object.
178+
179+
:param h: The PopulatedGraph object representing the graph.
180+
:type h: PopulatedGraph
181+
:param choice: The choice function used to select the root node.
182+
:type choice: Callable, optional
183+
:return: A list of balanced edge cuts.
184+
:rtype: List[Any]
185+
"""
186+
138187
root = choice([x for x in h if h.degree(x) > 1])
139188
pred = predecessors(h.graph, root)
140189
succ = successors(h.graph, root)
@@ -193,7 +242,6 @@ def bipartition_tree(
193242
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
194243
choice: Callable = random.choice,
195244
max_attempts: Optional[int] = 10000
196-
max_attempts: Optional[int] = None
197245
) -> Set:
198246
"""
199247
This function finds a balanced 2 partition of a graph by drawing a
@@ -280,8 +328,40 @@ def _bipartition_tree_random_all(
280328
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
281329
choice: Callable = random.choice,
282330
max_attempts: Optional[int] = None
283-
):
284-
"""Randomly bipartitions a graph and returns all cuts."""
331+
) -> List[Tuple[Hashable, Hashable]]:
332+
"""
333+
Randomly bipartitions a tree into two subgraphs until a valid bipartition is found.
334+
335+
:param graph: The input graph.
336+
:type graph: nx.Graph
337+
:param pop_col: The name of the column in the graph nodes that contains the population data.
338+
:type pop_col: str
339+
:param pop_target: The target population for each subgraph.
340+
:type pop_target: Union[int, float]
341+
:param epsilon: The allowed deviation from the target population.
342+
:type epsilon: float
343+
:param node_repeats: The number of times to repeat the bipartitioning process. Defaults to 1.
344+
:type node_repeats: int, optional
345+
:param repeat_until_valid: Whether to repeat the bipartitioning process until a valid bipartition is found. Defaults to True.
346+
:type repeat_until_valid: bool, optional
347+
:param spanning_tree: The spanning tree to use for bipartitioning. If None, a random spanning tree will be generated. Defaults to None.
348+
:type spanning_tree: Optional[nx.Graph], optional
349+
:param spanning_tree_fn: The function to generate a spanning tree. Defaults to random_spanning_tree.
350+
:type spanning_tree_fn: Callable, optional
351+
:param balance_edge_fn: The function to find balanced edge cuts. Defaults to find_balanced_edge_cuts_memoization.
352+
:type balance_edge_fn: Callable, optional
353+
:param choice: The function to choose a random element from a list. Defaults to random.choice.
354+
:type choice: Callable, optional
355+
:param max_attempts: The maximum number of attempts to find a valid bipartition. If None, there is no limit. Defaults to None.
356+
:type max_attempts: Optional[int], optional
357+
358+
:returns: A list of possible cuts that bipartition the tree into two subgraphs.
359+
:rtype: List[Tuple[Hashable, Hashable]]
360+
361+
:raises RuntimeError: If a valid bipartition cannot be found after the specified number of attempts.
362+
"""
363+
364+
285365
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}
286366

287367
possible_cuts = []
@@ -321,8 +401,9 @@ def bipartition_tree_random(
321401
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
322402
choice: Callable = random.choice,
323403
max_attempts: Optional[int] = None
324-
):
325-
"""This is like :func:`bipartition_tree` except it chooses a random balanced
404+
) -> Union[Set[Any], None]:
405+
"""
406+
This is like :func:`bipartition_tree` except it chooses a random balanced
326407
cut, rather than the first cut it finds.
327408
328409
This function finds a balanced 2 partition of a graph by drawing a
@@ -333,27 +414,38 @@ def bipartition_tree_random(
333414
Builds up a connected subgraph with a connected complement whose population
334415
is ``epsilon * pop_target`` away from ``pop_target``.
335416
336-
Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
337-
The other part of the partition is the complement of this subset.
338-
339-
:param graph: The graph to partition
340-
:param pop_col: The node attribute holding the population of each node
341-
:param pop_target: The target population for the returned subset of nodes
417+
:param graph: The graph to partition (must be an instance of nx.Graph)
418+
:type graph: nx.Graph
419+
:param pop_col: The node attribute holding the population of each node (must be a string)
420+
:type pop_col: str
421+
:param pop_target: The target population for the returned subset of nodes (must be an int or float)
422+
:type pop_target: Union[int, float]
342423
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
343-
``pop_target``) for the subgraph's population
424+
``pop_target``) for the subgraph's population (must be a float)
425+
:type epsilon: float
344426
:param node_repeats: A parameter for the algorithm: how many different choices
345-
of root to use before drawing a new spanning tree.
427+
of root to use before drawing a new spanning tree (default is 1, must be an int)
428+
:type node_repeats: int
346429
:param repeat_until_valid: Determines whether to keep drawing spanning trees
347430
until a tree with a balanced cut is found. If `True`, a set of nodes will
348431
always be returned; if `False`, `None` will be returned if a valid spanning
349-
tree is not found on the first try.
432+
tree is not found on the first try (default is True, must be a bool)
433+
:type repeat_until_valid: bool
350434
:param spanning_tree: The spanning tree for the algorithm to use (used when the
351-
algorithm chooses a new root and for testing)
435+
algorithm chooses a new root and for testing) (must be an instance of nx.Graph or None)
436+
:type spanning_tree: Optional[nx.Graph]
352437
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
353-
tree is not provided
354-
:param balance_edge_fn: The algorithm used to find balanced cut edges
355-
:param choice: :func:`random.choice`. Can be substituted for testing.
356-
:param max_atempts: The max number of attempts that should be made to bipartition.
438+
tree is not provided (must be a callable)
439+
:type spanning_tree_fn: Callable
440+
:param balance_edge_fn: The algorithm used to find balanced cut edges (must be a callable)
441+
:type balance_edge_fn: Callable
442+
:param choice: :func:`random.choice`. Can be substituted for testing. (must be a callable)
443+
:type choice: Callable
444+
:param max_attempts: The max number of attempts that should be made to bipartition. (must be an int or None)
445+
:type max_attempts: Optional[int]
446+
447+
:return: A subset of nodes of ``graph`` (whose induced subgraph is connected) or None if a valid spanning tree is not found.
448+
:rtype: Union[Set[Any], None]
357449
"""
358450
possible_cuts = _bipartition_tree_random_all(
359451
graph=graph,
@@ -381,18 +473,28 @@ def recursive_tree_part(
381473
node_repeats: int = 1,
382474
method: Callable = partial(bipartition_tree, max_attempts=10000)
383475
) -> Dict:
384-
"""Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a tree into
476+
"""
477+
Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a tree into
385478
``len(parts)`` parts of population ``pop_target`` (within ``epsilon``). Can be used to
386479
generate initial seed plans or to implement ReCom-like "merge walk" proposals.
387480
388481
:param graph: The graph
389-
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``
482+
:type graph: nx.Graph
483+
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``)
484+
:type parts: Sequence
390485
:param pop_target: Target population for each part of the partition
486+
:type pop_target: Union[float, int]
391487
:param pop_col: Node attribute key holding population data
488+
:type pop_col: str
392489
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
393490
of the partition can be
491+
:type epsilon: float
394492
:param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use.
395-
:param method: The partition method to use.
493+
Defaults to 1.
494+
:type node_repeats: int, optional
495+
:param method: The partition method to use. Defaults to
496+
`partial(bipartition_tree, max_attempts=10000)`.
497+
:type method: Callable, optional
396498
:return: New assignments for the nodes of ``graph``.
397499
:rtype: dict
398500
"""
@@ -452,12 +554,24 @@ def get_seed_chunks(
452554
balanced within new_epsilon <= ``epsilon`` of a balanced target population.
453555
454556
:param graph: The graph
455-
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``
456-
:param pop_target: target population of the districts (not of the chunks)
557+
:type graph: nx.Graph
558+
:param num_chunks: The number of chunks to partition the graph into
559+
:type num_chunks: int
560+
:param num_dists: The number of districts
561+
:type num_dists: int
562+
:param pop_target: The target population of the districts (not of the chunks)
563+
:type pop_target: Union[int, float]
457564
:param pop_col: Node attribute key holding population data
565+
:type pop_col: str
458566
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
459567
of the partition can be
460-
:param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use.
568+
:type epsilon: float
569+
:param node_repeats: Parameter for :func:`~gerrychain.tree.bipartition_tree_random`
570+
to use.
571+
:type node_repeats: int, optional
572+
:param method: The method to use for bipartitioning the graph.
573+
Defaults to :func:`~gerrychain.tree.bipartition_tree_random`
574+
:type method: Callable, optional
461575
:return: New assignments for the nodes of ``graph``.
462576
:rtype: dict
463577
"""
@@ -534,29 +648,40 @@ def get_seed_chunks(
534648

535649

536650
def get_max_prime_factor_less_than(
537-
n, ceil
538-
):
651+
n: int, ceil: int
652+
) -> Optional[int]:
539653
"""
540-
Helper function for recursive_seed_part. Returns the largest prime factor of ``n`` less than
654+
Helper function for recursive_seed_part_inner. Returns the largest prime factor of ``n`` no greater than
541655
``ceil``, or None if all are greater than ceil.
656+
657+
:param n: The number to find the largest prime factor for.
658+
:type n: int
659+
:param ceil: The upper limit for the largest prime factor.
660+
:type ceil: int
661+
:return: The largest prime factor of ``n`` less than or equal to ``ceil``, or None if ``n`` has no prime factor that small.
662+
:rtype: int or None
542663
"""
543-
factors = []
544-
i = 2
664+
if n <= 1 or ceil <= 1:
665+
return None
666+
667+
largest_factor = None
668+
while n % 2 == 0:
669+
largest_factor = 2
670+
n //= 2
671+
672+
i = 3
545673
while i * i <= n:
546-
if n % i:
547-
i += 1
548-
else:
674+
while n % i == 0:
675+
if i <= ceil:
676+
largest_factor = i
549677
n //= i
550-
factors.append(i)
551-
if n > 1:
552-
factors.append(n)
553-
554-
if len(factors) == 0:
555-
return 1
556-
m = [i for i in factors if i <= ceil]
557-
if m == []:
558-
return None
559-
return int(max(m))
678+
i += 2
679+
680+
if n > 1 and n <= ceil:
681+
largest_factor = n
682+
683+
return largest_factor
684+
560685

561686

562687
def recursive_seed_part_inner(
@@ -681,29 +806,40 @@ def recursive_seed_part(
681806
method: Callable = partial(bipartition_tree, max_attempts=10000),
682807
node_repeats: int = 1,
683808
n: Optional[int] = None,
684-
ceil: None = None
809+
ceil: Optional[int] = None
685810
) -> Dict:
686811
"""
687812
Returns a partition with ``num_dists`` districts balanced within ``epsilon`` of
688813
``pop_target`` by recursively splitting graph using recursive_seed_part_inner.
689814
690815
:param graph: The graph
816+
:type graph: nx.Graph
691817
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``)
818+
:type parts: Sequence
692819
:param pop_target: Target population for each part of the partition
820+
:type pop_target: Union[float, int]
693821
:param pop_col: Node attribute key holding population data
822+
:type pop_col: str
694823
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
695824
of the partition can be
825+
:type epsilon: float
696826
:param method: Function used to find balanced partitions at the 2-district level
827+
Defaults to :func:`~gerrychain.tree.bipartition_tree` with ``max_attempts=10000``.
828+
:type method: Callable
697829
:param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use.
830+
Defaults to 1.
831+
:type node_repeats: int, optional
698832
:param n: Either a positive integer (greater than 1) or None. If n is a positive integer,
699-
this function will recursively create a seed plan by either biting off districts from graph
700-
or dividing graph into n chunks and recursing into each of these. If n is None, this
701-
function prime factors ``num_dists``=n_1*n_2*...*n_k (n_1 > n_2 > ... n_k) and recursively
702-
partitions graph into n_1 chunks.
833+
this function will recursively create a seed plan by either biting off districts from graph
834+
or dividing graph into n chunks and recursing into each of these. If n is None, this
835+
function prime factors ``num_dists``=n_1*n_2*...*n_k (n_1 > n_2 > ... n_k) and recursively
836+
partitions graph into n_1 chunks.
837+
:type n: Optional[int]
703838
:param ceil: Either a positive integer (at least 2) or None. Relevant only if n is None. If
704-
``ceil`` is a positive integer then finds the largest factor of ``num_dists`` less than or
705-
equal to ``ceil``, and recursively splits graph into that number of chunks, or bites off a
706-
district if that number is 1.
839+
``ceil`` is a positive integer then finds the largest factor of ``num_dists`` less than or
840+
equal to ``ceil``, and recursively splits graph into that number of chunks, or bites off a
841+
district if that number is 1. Defaults to None.
842+
:type ceil: Optional[int]
707843
:return: New assignments for the nodes of ``graph``.
708844
:rtype: dict
709845
"""

0 commit comments

Comments
 (0)