35
35
import random
36
36
from collections import deque , namedtuple
37
37
from typing import Any , Callable , Dict , List , Optional , Set , Union , Hashable , Sequence , Tuple
38
+ import warnings
38
39
39
40
40
41
def predecessors (h : nx .Graph , root : Any ) -> Dict :
@@ -295,6 +296,22 @@ def part_nodes(start):
295
296
return cuts
296
297
297
298
299
+ class BipartitionWarning (UserWarning ):
300
+ """
301
+ Generally raised when it is proving difficult to find a balanced cut.
302
+ """
303
+ pass
304
+
305
+
306
+ class ReselectException (Exception ):
307
+ """
308
+ Raised when the algorithm is unable to find a balanced cut after some
309
+ maximum number of attempts, but the user has allowed the algorithm to
310
+ reselect the pair of nodes to try and recombine.
311
+ """
312
+ pass
313
+
314
+
298
315
def bipartition_tree (
299
316
graph : nx .Graph ,
300
317
pop_col : str ,
@@ -306,7 +323,8 @@ def bipartition_tree(
306
323
weight_dict : Optional [Dict ] = None ,
307
324
balance_edge_fn : Callable = find_balanced_edge_cuts_memoization ,
308
325
choice : Callable = random .choice ,
309
- max_attempts : Optional [int ] = 10000
326
+ max_attempts : Optional [int ] = 10000 ,
327
+ allow_pair_reselection : bool = False
310
328
) -> Set :
311
329
"""
312
330
This function finds a balanced 2 partition of a graph by drawing a
@@ -347,10 +365,15 @@ def bipartition_tree(
347
365
:param max_attempts: The maximum number of attempts that should be made to bipartition.
348
366
Defaults to 1000.
349
367
:type max_attempts: Optional[int], optional
368
+ :param allow_pair_reselection: Whether we would like to return an error to the calling
369
+ function to ask it to reselect the pair of nodes to try and recombine. Defaults to False.
370
+ :type allow_pair_reselection: bool, optional
350
371
351
372
:returns: A subset of nodes of ``graph`` (whose induced subgraph is connected). The other
352
373
part of the partition is the complement of this subset.
353
374
:rtype: Set
375
+
376
+ :raises BipartitionWarning: If a possible cut cannot be found after 50 attempts.
354
377
:raises RuntimeError: If a possible cut cannot be found after the maximum number of attempts.
355
378
"""
356
379
# Try to add the region-aware in if the spanning_tree_fn accepts a weight dictionary
@@ -378,6 +401,17 @@ def bipartition_tree(
378
401
restarts += 1
379
402
attempts += 1
380
403
404
+ if attempts == 50 and not allow_pair_reselection :
405
+ warnings .warn ("Failed to find a balanced cut after 50 attempts.\n "
406
+ "Consider running with the parameter\n "
407
+ "allow_pair_reselection=True to allow the algorithm to\n "
408
+ "select a different pair of nodes to try an recombine." ,
409
+ BipartitionWarning )
410
+
411
+ if allow_pair_reselection :
412
+ raise ReselectException (f"Failed to find a balanced cut after { max_attempts } attempts.\n "
413
+ f"Selecting a new district pair" )
414
+
381
415
raise RuntimeError (f"Could not find a possible cut after { max_attempts } attempts." )
382
416
383
417
@@ -589,13 +623,17 @@ def recursive_tree_part(
589
623
min_pop = max (pop_target * (1 - epsilon ), pop_target * (1 - epsilon ) - debt )
590
624
max_pop = min (pop_target * (1 + epsilon ), pop_target * (1 + epsilon ) - debt )
591
625
new_pop_target = (min_pop + max_pop ) / 2
592
- nodes = method (
593
- graph .subgraph (remaining_nodes ),
594
- pop_col = pop_col ,
595
- pop_target = new_pop_target ,
596
- epsilon = (max_pop - min_pop ) / (2 * new_pop_target ),
597
- node_repeats = node_repeats ,
598
- )
626
+
627
+ try :
628
+ nodes = method (
629
+ graph .subgraph (remaining_nodes ),
630
+ pop_col = pop_col ,
631
+ pop_target = new_pop_target ,
632
+ epsilon = (max_pop - min_pop ) / (2 * new_pop_target ),
633
+ node_repeats = node_repeats ,
634
+ )
635
+ except Exception :
636
+ raise
599
637
600
638
if nodes is None :
601
639
raise BalanceError ()
0 commit comments